[SPARK-14628][CORE] Simplify task metrics by always tracking read/write metrics

## What changes were proposed in this pull request?

Part of the reason why TaskMetrics and its callers are complicated is the 
optional metrics we collect, including input, output, shuffle read, and 
shuffle write. I think we can always track them and just assign 0 as the 
initial values. It is usually very obvious whether a task is supposed to read 
any data or not. By always tracking them, we can remove a lot of map, foreach, 
flatMap, and getOrElse(0L) calls throughout Spark.

This patch also changes a few behaviors.

1. Removes the distinction between data read/write methods (e.g. Hadoop, 
Memory, Network, etc.).
2. Accumulates all data reads and writes, rather than only those from the 
first read/write method seen by a task. (Fixes SPARK-5225)

## How was this patch tested?

Existing tests.

This is based on https://github.com/apache/spark/pull/12388, with more test 
fixes.

Author: Reynold Xin <[email protected]>
Author: Wenchen Fan <[email protected]>

Closes #12417 from cloud-fan/metrics-refactor.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/8028a288
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/8028a288
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/8028a288

Branch: refs/heads/master
Commit: 8028a28885dbd90f20e38922240618fc310a0a65
Parents: 90b46e0
Author: Reynold Xin <[email protected]>
Authored: Fri Apr 15 15:39:39 2016 -0700
Committer: Reynold Xin <[email protected]>
Committed: Fri Apr 15 15:39:39 2016 -0700

----------------------------------------------------------------------
 .../sort/BypassMergeSortShuffleWriter.java      |   2 +-
 .../spark/shuffle/sort/UnsafeShuffleWriter.java |   2 +-
 .../unsafe/sort/UnsafeExternalSorter.java       |   2 +-
 .../org/apache/spark/InternalAccumulator.scala  |   6 -
 .../apache/spark/executor/InputMetrics.scala    |  27 +---
 .../apache/spark/executor/OutputMetrics.scala   |  15 +-
 .../spark/executor/ShuffleReadMetrics.scala     |   7 +-
 .../spark/executor/ShuffleWriteMetrics.scala    |   7 +-
 .../org/apache/spark/executor/TaskMetrics.scala | 122 ++--------------
 .../scala/org/apache/spark/rdd/HadoopRDD.scala  |   2 +-
 .../org/apache/spark/rdd/NewHadoopRDD.scala     |   2 +-
 .../org/apache/spark/rdd/PairRDDFunctions.scala |   2 +-
 .../main/scala/org/apache/spark/rdd/RDD.scala   |   2 +-
 .../spark/scheduler/StatsReportListener.scala   |  46 +++---
 .../spark/shuffle/BlockStoreShuffleReader.scala |   2 +-
 .../spark/shuffle/hash/HashShuffleWriter.scala  |   2 +-
 .../spark/shuffle/sort/SortShuffleWriter.scala  |   2 +-
 .../spark/status/api/v1/AllStagesResource.scala | 100 ++++++++-----
 .../storage/ShuffleBlockFetcherIterator.scala   |   2 +-
 .../org/apache/spark/ui/exec/ExecutorsTab.scala |  33 ++---
 .../spark/ui/jobs/JobProgressListener.scala     |  32 ++---
 .../org/apache/spark/ui/jobs/StagePage.scala    |  37 +++--
 .../org/apache/spark/util/JsonProtocol.scala    |  66 ++++-----
 .../spark/util/collection/ExternalSorter.scala  |   2 +-
 .../shuffle/sort/UnsafeShuffleWriterSuite.java  |  12 +-
 .../apache/spark/InternalAccumulatorSuite.scala |   5 -
 .../scala/org/apache/spark/ShuffleSuite.scala   |  12 +-
 .../spark/executor/TaskMetricsSuite.scala       | 122 ++--------------
 .../spark/metrics/InputOutputMetricsSuite.scala |  69 ++-------
 .../spark/scheduler/SparkListenerSuite.scala    |  15 +-
 .../BypassMergeSortShuffleWriterSuite.scala     |   4 +-
 .../ui/jobs/JobProgressListenerSuite.scala      |  31 ++--
 .../apache/spark/util/JsonProtocolSuite.scala   | 143 +++++++------------
 project/MimaExcludes.scala                      |   5 +-
 34 files changed, 330 insertions(+), 610 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/8028a288/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
----------------------------------------------------------------------
diff --git 
a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
 
b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
index 7a60c3e..0e9defe 100644
--- 
a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
+++ 
b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
@@ -114,7 +114,7 @@ final class BypassMergeSortShuffleWriter<K, V> extends 
ShuffleWriter<K, V> {
     this.shuffleId = dep.shuffleId();
     this.partitioner = dep.partitioner();
     this.numPartitions = partitioner.numPartitions();
-    this.writeMetrics = 
taskContext.taskMetrics().registerShuffleWriteMetrics();
+    this.writeMetrics = taskContext.taskMetrics().shuffleWriteMetrics();
     this.serializer = dep.serializer();
     this.shuffleBlockResolver = shuffleBlockResolver;
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/8028a288/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
----------------------------------------------------------------------
diff --git 
a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java 
b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
index 0c5fb88..daa63d4 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
@@ -118,7 +118,7 @@ public class UnsafeShuffleWriter<K, V> extends 
ShuffleWriter<K, V> {
     this.shuffleId = dep.shuffleId();
     this.serializer = dep.serializer().newInstance();
     this.partitioner = dep.partitioner();
-    this.writeMetrics = 
taskContext.taskMetrics().registerShuffleWriteMetrics();
+    this.writeMetrics = taskContext.taskMetrics().shuffleWriteMetrics();
     this.taskContext = taskContext;
     this.sparkConf = sparkConf;
     this.transferToEnabled = sparkConf.getBoolean("spark.file.transferTo", 
true);

http://git-wip-us.apache.org/repos/asf/spark/blob/8028a288/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
----------------------------------------------------------------------
diff --git 
a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
 
b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
index ef79b49..3e32dd9 100644
--- 
a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
+++ 
b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
@@ -129,7 +129,7 @@ public final class UnsafeExternalSorter extends 
MemoryConsumer {
     // Use getSizeAsKb (not bytes) to maintain backwards compatibility for 
units
     // this.fileBufferSizeBytes = (int) 
conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024;
     this.fileBufferSizeBytes = 32 * 1024;
-    this.writeMetrics = 
taskContext.taskMetrics().registerShuffleWriteMetrics();
+    this.writeMetrics = taskContext.taskMetrics().shuffleWriteMetrics();
 
     if (existingInMemorySorter == null) {
       this.inMemSorter = new UnsafeInMemorySorter(

http://git-wip-us.apache.org/repos/asf/spark/blob/8028a288/core/src/main/scala/org/apache/spark/InternalAccumulator.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/InternalAccumulator.scala 
b/core/src/main/scala/org/apache/spark/InternalAccumulator.scala
index 0dd4ec6..714c873 100644
--- a/core/src/main/scala/org/apache/spark/InternalAccumulator.scala
+++ b/core/src/main/scala/org/apache/spark/InternalAccumulator.scala
@@ -68,14 +68,12 @@ private[spark] object InternalAccumulator {
 
   // Names of output metrics
   object output {
-    val WRITE_METHOD = OUTPUT_METRICS_PREFIX + "writeMethod"
     val BYTES_WRITTEN = OUTPUT_METRICS_PREFIX + "bytesWritten"
     val RECORDS_WRITTEN = OUTPUT_METRICS_PREFIX + "recordsWritten"
   }
 
   // Names of input metrics
   object input {
-    val READ_METHOD = INPUT_METRICS_PREFIX + "readMethod"
     val BYTES_READ = INPUT_METRICS_PREFIX + "bytesRead"
     val RECORDS_READ = INPUT_METRICS_PREFIX + "recordsRead"
   }
@@ -110,8 +108,6 @@ private[spark] object InternalAccumulator {
       case UPDATED_BLOCK_STATUSES => UpdatedBlockStatusesAccumulatorParam
       case shuffleRead.LOCAL_BLOCKS_FETCHED => IntAccumulatorParam
       case shuffleRead.REMOTE_BLOCKS_FETCHED => IntAccumulatorParam
-      case input.READ_METHOD => StringAccumulatorParam
-      case output.WRITE_METHOD => StringAccumulatorParam
       case _ => LongAccumulatorParam
     }
   }
@@ -165,7 +161,6 @@ private[spark] object InternalAccumulator {
    */
   def createInputAccums(): Seq[Accumulator[_]] = {
     Seq[String](
-      input.READ_METHOD,
       input.BYTES_READ,
       input.RECORDS_READ).map(create)
   }
@@ -175,7 +170,6 @@ private[spark] object InternalAccumulator {
    */
   def createOutputAccums(): Seq[Accumulator[_]] = {
     Seq[String](
-      output.WRITE_METHOD,
       output.BYTES_WRITTEN,
       output.RECORDS_WRITTEN).map(create)
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/8028a288/core/src/main/scala/org/apache/spark/executor/InputMetrics.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/executor/InputMetrics.scala 
b/core/src/main/scala/org/apache/spark/executor/InputMetrics.scala
index 83e11c5..2181bde 100644
--- a/core/src/main/scala/org/apache/spark/executor/InputMetrics.scala
+++ b/core/src/main/scala/org/apache/spark/executor/InputMetrics.scala
@@ -39,31 +39,13 @@ object DataReadMethod extends Enumeration with Serializable 
{
  * A collection of accumulators that represents metrics about reading data 
from external systems.
  */
 @DeveloperApi
-class InputMetrics private (
-    _bytesRead: Accumulator[Long],
-    _recordsRead: Accumulator[Long],
-    _readMethod: Accumulator[String])
+class InputMetrics private (_bytesRead: Accumulator[Long], _recordsRead: 
Accumulator[Long])
   extends Serializable {
 
   private[executor] def this(accumMap: Map[String, Accumulator[_]]) {
     this(
       TaskMetrics.getAccum[Long](accumMap, 
InternalAccumulator.input.BYTES_READ),
-      TaskMetrics.getAccum[Long](accumMap, 
InternalAccumulator.input.RECORDS_READ),
-      TaskMetrics.getAccum[String](accumMap, 
InternalAccumulator.input.READ_METHOD))
-  }
-
-  /**
-   * Create a new [[InputMetrics]] that is not associated with any particular 
task.
-   *
-   * This mainly exists because of SPARK-5225, where we are forced to use a 
dummy [[InputMetrics]]
-   * because we want to ignore metrics from a second read method. In the 
future, we should revisit
-   * whether this is needed.
-   *
-   * A better alternative is [[TaskMetrics.registerInputMetrics]].
-   */
-  private[executor] def this() {
-    this(InternalAccumulator.createInputAccums()
-      .map { a => (a.name.get, a) }.toMap[String, Accumulator[_]])
+      TaskMetrics.getAccum[Long](accumMap, 
InternalAccumulator.input.RECORDS_READ))
   }
 
   /**
@@ -77,13 +59,12 @@ class InputMetrics private (
   def recordsRead: Long = _recordsRead.localValue
 
   /**
-   * The source from which this task reads its input.
+   * Returns true if this metrics has been updated before.
    */
-  def readMethod: DataReadMethod.Value = 
DataReadMethod.withName(_readMethod.localValue)
+  def isUpdated: Boolean = (bytesRead | recordsRead) != 0
 
   private[spark] def incBytesRead(v: Long): Unit = _bytesRead.add(v)
   private[spark] def incRecordsRead(v: Long): Unit = _recordsRead.add(v)
   private[spark] def setBytesRead(v: Long): Unit = _bytesRead.setValue(v)
-  private[spark] def setReadMethod(v: DataReadMethod.Value): Unit = 
_readMethod.setValue(v.toString)
 
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/8028a288/core/src/main/scala/org/apache/spark/executor/OutputMetrics.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/executor/OutputMetrics.scala 
b/core/src/main/scala/org/apache/spark/executor/OutputMetrics.scala
index 93f9538..7f20f6b 100644
--- a/core/src/main/scala/org/apache/spark/executor/OutputMetrics.scala
+++ b/core/src/main/scala/org/apache/spark/executor/OutputMetrics.scala
@@ -38,17 +38,13 @@ object DataWriteMethod extends Enumeration with 
Serializable {
  * A collection of accumulators that represents metrics about writing data to 
external systems.
  */
 @DeveloperApi
-class OutputMetrics private (
-    _bytesWritten: Accumulator[Long],
-    _recordsWritten: Accumulator[Long],
-    _writeMethod: Accumulator[String])
+class OutputMetrics private (_bytesWritten: Accumulator[Long], 
_recordsWritten: Accumulator[Long])
   extends Serializable {
 
   private[executor] def this(accumMap: Map[String, Accumulator[_]]) {
     this(
       TaskMetrics.getAccum[Long](accumMap, 
InternalAccumulator.output.BYTES_WRITTEN),
-      TaskMetrics.getAccum[Long](accumMap, 
InternalAccumulator.output.RECORDS_WRITTEN),
-      TaskMetrics.getAccum[String](accumMap, 
InternalAccumulator.output.WRITE_METHOD))
+      TaskMetrics.getAccum[Long](accumMap, 
InternalAccumulator.output.RECORDS_WRITTEN))
   }
 
   /**
@@ -62,13 +58,10 @@ class OutputMetrics private (
   def recordsWritten: Long = _recordsWritten.localValue
 
   /**
-   * The source to which this task writes its output.
+   * Returns true if this metrics has been updated before.
    */
-  def writeMethod: DataWriteMethod.Value = 
DataWriteMethod.withName(_writeMethod.localValue)
+  def isUpdated: Boolean = (bytesWritten | recordsWritten) != 0
 
   private[spark] def setBytesWritten(v: Long): Unit = _bytesWritten.setValue(v)
   private[spark] def setRecordsWritten(v: Long): Unit = 
_recordsWritten.setValue(v)
-  private[spark] def setWriteMethod(v: DataWriteMethod.Value): Unit =
-    _writeMethod.setValue(v.toString)
-
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/8028a288/core/src/main/scala/org/apache/spark/executor/ShuffleReadMetrics.scala
----------------------------------------------------------------------
diff --git 
a/core/src/main/scala/org/apache/spark/executor/ShuffleReadMetrics.scala 
b/core/src/main/scala/org/apache/spark/executor/ShuffleReadMetrics.scala
index 71a2477..9c78995 100644
--- a/core/src/main/scala/org/apache/spark/executor/ShuffleReadMetrics.scala
+++ b/core/src/main/scala/org/apache/spark/executor/ShuffleReadMetrics.scala
@@ -53,7 +53,7 @@ class ShuffleReadMetrics private (
    * many places only to merge their values together later. In the future, we 
should revisit
    * whether this is needed.
    *
-   * A better alternative is [[TaskMetrics.registerTempShuffleReadMetrics]] 
followed by
+   * A better alternative is [[TaskMetrics.createTempShuffleReadMetrics]] 
followed by
    * [[TaskMetrics.mergeShuffleReadMetrics]].
    */
   private[spark] def this() {
@@ -102,6 +102,11 @@ class ShuffleReadMetrics private (
    */
   def totalBlocksFetched: Int = remoteBlocksFetched + localBlocksFetched
 
+  /**
+   * Returns true if this metrics has been updated before.
+   */
+  def isUpdated: Boolean = (totalBytesRead | totalBlocksFetched | recordsRead 
| fetchWaitTime) != 0
+
   private[spark] def incRemoteBlocksFetched(v: Int): Unit = 
_remoteBlocksFetched.add(v)
   private[spark] def incLocalBlocksFetched(v: Int): Unit = 
_localBlocksFetched.add(v)
   private[spark] def incRemoteBytesRead(v: Long): Unit = 
_remoteBytesRead.add(v)

http://git-wip-us.apache.org/repos/asf/spark/blob/8028a288/core/src/main/scala/org/apache/spark/executor/ShuffleWriteMetrics.scala
----------------------------------------------------------------------
diff --git 
a/core/src/main/scala/org/apache/spark/executor/ShuffleWriteMetrics.scala 
b/core/src/main/scala/org/apache/spark/executor/ShuffleWriteMetrics.scala
index c7aaabb..cf570e1 100644
--- a/core/src/main/scala/org/apache/spark/executor/ShuffleWriteMetrics.scala
+++ b/core/src/main/scala/org/apache/spark/executor/ShuffleWriteMetrics.scala
@@ -47,7 +47,7 @@ class ShuffleWriteMetrics private (
    * many places only to merge their values together later. In the future, we 
should revisit
    * whether this is needed.
    *
-   * A better alternative is [[TaskMetrics.registerShuffleWriteMetrics]].
+   * A better alternative is [[TaskMetrics.shuffleWriteMetrics]].
    */
   private[spark] def this() {
     this(InternalAccumulator.createShuffleWriteAccums().map { a => 
(a.name.get, a) }.toMap)
@@ -68,6 +68,11 @@ class ShuffleWriteMetrics private (
    */
   def writeTime: Long = _writeTime.localValue
 
+  /**
+   * Returns true if this metrics has been updated before.
+   */
+  def isUpdated: Boolean = (writeTime | recordsWritten | bytesWritten) != 0
+
   private[spark] def incBytesWritten(v: Long): Unit = _bytesWritten.add(v)
   private[spark] def incRecordsWritten(v: Long): Unit = _recordsWritten.add(v)
   private[spark] def incWriteTime(v: Long): Unit = _writeTime.add(v)

http://git-wip-us.apache.org/repos/asf/spark/blob/8028a288/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala 
b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
index bda2a91..0198364 100644
--- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
+++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
@@ -91,6 +91,14 @@ class TaskMetrics private[spark] (initialAccums: 
Seq[Accumulator[_]]) extends Se
   private val _updatedBlockStatuses =
     TaskMetrics.getAccum[Seq[(BlockId, BlockStatus)]](initialAccumsMap, 
UPDATED_BLOCK_STATUSES)
 
+  private val _inputMetrics = new InputMetrics(initialAccumsMap)
+
+  private val _outputMetrics = new OutputMetrics(initialAccumsMap)
+
+  private val _shuffleReadMetrics = new ShuffleReadMetrics(initialAccumsMap)
+
+  private val _shuffleWriteMetrics = new ShuffleWriteMetrics(initialAccumsMap)
+
   /**
    * Time taken on the executor to deserialize this task.
    */
@@ -163,83 +171,23 @@ class TaskMetrics private[spark] (initialAccums: 
Seq[Accumulator[_]]) extends Se
     TaskMetrics.getAccum[Long](initialAccumsMap, name)
   }
 
-
-  /* ========================== *
-   |        INPUT METRICS       |
-   * ========================== */
-
-  private var _inputMetrics: Option[InputMetrics] = None
-
   /**
    * Metrics related to reading data from a [[org.apache.spark.rdd.HadoopRDD]] 
or from persisted
    * data, defined only in tasks with input.
    */
-  def inputMetrics: Option[InputMetrics] = _inputMetrics
-
-  /**
-   * Get or create a new [[InputMetrics]] associated with this task.
-   */
-  private[spark] def registerInputMetrics(readMethod: DataReadMethod.Value): 
InputMetrics = {
-    synchronized {
-      val metrics = _inputMetrics.getOrElse {
-        val metrics = new InputMetrics(initialAccumsMap)
-        metrics.setReadMethod(readMethod)
-        _inputMetrics = Some(metrics)
-        metrics
-      }
-      // If there already exists an InputMetric with the same read method, we 
can just return
-      // that one. Otherwise, if the read method is different from the one 
previously seen by
-      // this task, we return a new dummy one to avoid clobbering the values 
of the old metrics.
-      // In the future we should try to store input metrics from all different 
read methods at
-      // the same time (SPARK-5225).
-      if (metrics.readMethod == readMethod) {
-        metrics
-      } else {
-        val m = new InputMetrics
-        m.setReadMethod(readMethod)
-        m
-      }
-    }
-  }
-
-
-  /* ============================ *
-   |        OUTPUT METRICS        |
-   * ============================ */
-
-  private var _outputMetrics: Option[OutputMetrics] = None
+  def inputMetrics: InputMetrics = _inputMetrics
 
   /**
    * Metrics related to writing data externally (e.g. to a distributed 
filesystem),
    * defined only in tasks with output.
    */
-  def outputMetrics: Option[OutputMetrics] = _outputMetrics
-
-  /**
-   * Get or create a new [[OutputMetrics]] associated with this task.
-   */
-  private[spark] def registerOutputMetrics(
-      writeMethod: DataWriteMethod.Value): OutputMetrics = synchronized {
-    _outputMetrics.getOrElse {
-      val metrics = new OutputMetrics(initialAccumsMap)
-      metrics.setWriteMethod(writeMethod)
-      _outputMetrics = Some(metrics)
-      metrics
-    }
-  }
-
-
-  /* ================================== *
-   |        SHUFFLE READ METRICS        |
-   * ================================== */
-
-  private var _shuffleReadMetrics: Option[ShuffleReadMetrics] = None
+  def outputMetrics: OutputMetrics = _outputMetrics
 
   /**
    * Metrics related to shuffle read aggregated across all shuffle 
dependencies.
    * This is defined only if there are shuffle dependencies in this task.
    */
-  def shuffleReadMetrics: Option[ShuffleReadMetrics] = _shuffleReadMetrics
+  def shuffleReadMetrics: ShuffleReadMetrics = _shuffleReadMetrics
 
   /**
    * Temporary list of [[ShuffleReadMetrics]], one per shuffle dependency.
@@ -257,7 +205,7 @@ class TaskMetrics private[spark] (initialAccums: 
Seq[Accumulator[_]]) extends Se
    * merges the temporary values synchronously. Otherwise, all temporary data 
collected will
    * be lost.
    */
-  private[spark] def registerTempShuffleReadMetrics(): ShuffleReadMetrics = 
synchronized {
+  private[spark] def createTempShuffleReadMetrics(): ShuffleReadMetrics = 
synchronized {
     val readMetrics = new ShuffleReadMetrics
     tempShuffleReadMetrics += readMetrics
     readMetrics
@@ -269,34 +217,14 @@ class TaskMetrics private[spark] (initialAccums: 
Seq[Accumulator[_]]) extends Se
    */
   private[spark] def mergeShuffleReadMetrics(): Unit = synchronized {
     if (tempShuffleReadMetrics.nonEmpty) {
-      val metrics = new ShuffleReadMetrics(initialAccumsMap)
-      metrics.setMergeValues(tempShuffleReadMetrics)
-      _shuffleReadMetrics = Some(metrics)
+      _shuffleReadMetrics.setMergeValues(tempShuffleReadMetrics)
     }
   }
 
-  /* =================================== *
-   |        SHUFFLE WRITE METRICS        |
-   * =================================== */
-
-  private var _shuffleWriteMetrics: Option[ShuffleWriteMetrics] = None
-
   /**
    * Metrics related to shuffle write, defined only in shuffle map stages.
    */
-  def shuffleWriteMetrics: Option[ShuffleWriteMetrics] = _shuffleWriteMetrics
-
-  /**
-   * Get or create a new [[ShuffleWriteMetrics]] associated with this task.
-   */
-  private[spark] def registerShuffleWriteMetrics(): ShuffleWriteMetrics = 
synchronized {
-    _shuffleWriteMetrics.getOrElse {
-      val metrics = new ShuffleWriteMetrics(initialAccumsMap)
-      _shuffleWriteMetrics = Some(metrics)
-      metrics
-    }
-  }
-
+  def shuffleWriteMetrics: ShuffleWriteMetrics = _shuffleWriteMetrics
 
   /* ========================== *
    |        OTHER THINGS        |
@@ -316,28 +244,6 @@ class TaskMetrics private[spark] (initialAccums: 
Seq[Accumulator[_]]) extends Se
   def accumulatorUpdates(): Seq[AccumulableInfo] = {
     accums.map { a => a.toInfo(Some(a.localValue), None) }
   }
-
-  // If we are reconstructing this TaskMetrics on the driver, some metrics may 
already be set.
-  // If so, initialize all relevant metrics classes so listeners can access 
them downstream.
-  {
-    var (hasShuffleRead, hasShuffleWrite, hasInput, hasOutput) = (false, 
false, false, false)
-    initialAccums
-      .filter { a => a.localValue != a.zero }
-      .foreach { a =>
-        a.name.get match {
-          case sr if sr.startsWith(SHUFFLE_READ_METRICS_PREFIX) => 
hasShuffleRead = true
-          case sw if sw.startsWith(SHUFFLE_WRITE_METRICS_PREFIX) => 
hasShuffleWrite = true
-          case in if in.startsWith(INPUT_METRICS_PREFIX) => hasInput = true
-          case out if out.startsWith(OUTPUT_METRICS_PREFIX) => hasOutput = true
-          case _ =>
-        }
-      }
-    if (hasShuffleRead) { _shuffleReadMetrics = Some(new 
ShuffleReadMetrics(initialAccumsMap)) }
-    if (hasShuffleWrite) { _shuffleWriteMetrics = Some(new 
ShuffleWriteMetrics(initialAccumsMap)) }
-    if (hasInput) { _inputMetrics = Some(new InputMetrics(initialAccumsMap)) }
-    if (hasOutput) { _outputMetrics = Some(new 
OutputMetrics(initialAccumsMap)) }
-  }
-
 }
 
 /**

http://git-wip-us.apache.org/repos/asf/spark/blob/8028a288/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala 
b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
index 35d190b..6b1e155 100644
--- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
@@ -213,7 +213,7 @@ class HadoopRDD[K, V](
       logInfo("Input split: " + split.inputSplit)
       val jobConf = getJobConf()
 
-      val inputMetrics = 
context.taskMetrics().registerInputMetrics(DataReadMethod.Hadoop)
+      val inputMetrics = context.taskMetrics().inputMetrics
       val existingBytesRead = inputMetrics.bytesRead
 
       // Sets the thread local variable for the file's name

http://git-wip-us.apache.org/repos/asf/spark/blob/8028a288/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala 
b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
index 3ccd616..a71c191 100644
--- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
@@ -130,7 +130,7 @@ class NewHadoopRDD[K, V](
       logInfo("Input split: " + split.serializableHadoopSplit)
       val conf = getConf
 
-      val inputMetrics = 
context.taskMetrics().registerInputMetrics(DataReadMethod.Hadoop)
+      val inputMetrics = context.taskMetrics().inputMetrics
       val existingBytesRead = inputMetrics.bytesRead
 
       // Find a function that will return the FileSystem bytes read by this 
thread. Do this before

http://git-wip-us.apache.org/repos/asf/spark/blob/8028a288/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala 
b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
index 085829a..7936d8e 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
@@ -1218,7 +1218,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
       context: TaskContext): Option[(OutputMetrics, () => Long)] = {
     val bytesWrittenCallback = 
SparkHadoopUtil.get.getFSBytesWrittenOnThreadCallback()
     bytesWrittenCallback.map { b =>
-      (context.taskMetrics().registerOutputMetrics(DataWriteMethod.Hadoop), b)
+      (context.taskMetrics().outputMetrics, b)
     }
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/8028a288/core/src/main/scala/org/apache/spark/rdd/RDD.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala 
b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 36ff3bc..f6e0148 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -332,7 +332,7 @@ abstract class RDD[T: ClassTag](
     }) match {
       case Left(blockResult) =>
         if (readCachedBlock) {
-          val existingMetrics = 
context.taskMetrics().registerInputMetrics(blockResult.readMethod)
+          val existingMetrics = context.taskMetrics().inputMetrics
           existingMetrics.incBytesRead(blockResult.bytes)
           new InterruptibleIterator[T](context, 
blockResult.data.asInstanceOf[Iterator[T]]) {
             override def next(): T = {

http://git-wip-us.apache.org/repos/asf/spark/blob/8028a288/core/src/main/scala/org/apache/spark/scheduler/StatsReportListener.scala
----------------------------------------------------------------------
diff --git 
a/core/src/main/scala/org/apache/spark/scheduler/StatsReportListener.scala 
b/core/src/main/scala/org/apache/spark/scheduler/StatsReportListener.scala
index 309f4b8..3c8cab7 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/StatsReportListener.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/StatsReportListener.scala
@@ -47,19 +47,19 @@ class StatsReportListener extends SparkListener with 
Logging {
   override def onStageCompleted(stageCompleted: SparkListenerStageCompleted) {
     implicit val sc = stageCompleted
     this.logInfo(s"Finished stage: 
${getStatusDetail(stageCompleted.stageInfo)}")
-    showMillisDistribution("task runtime:", (info, _) => Some(info.duration), 
taskInfoMetrics)
+    showMillisDistribution("task runtime:", (info, _) => info.duration, 
taskInfoMetrics)
 
     // Shuffle write
     showBytesDistribution("shuffle bytes written:",
-      (_, metric) => metric.shuffleWriteMetrics.map(_.bytesWritten), 
taskInfoMetrics)
+      (_, metric) => metric.shuffleWriteMetrics.bytesWritten, taskInfoMetrics)
 
     // Fetch & I/O
     showMillisDistribution("fetch wait time:",
-      (_, metric) => metric.shuffleReadMetrics.map(_.fetchWaitTime), 
taskInfoMetrics)
+      (_, metric) => metric.shuffleReadMetrics.fetchWaitTime, taskInfoMetrics)
     showBytesDistribution("remote bytes read:",
-      (_, metric) => metric.shuffleReadMetrics.map(_.remoteBytesRead), 
taskInfoMetrics)
+      (_, metric) => metric.shuffleReadMetrics.remoteBytesRead, 
taskInfoMetrics)
     showBytesDistribution("task result size:",
-      (_, metric) => Some(metric.resultSize), taskInfoMetrics)
+      (_, metric) => metric.resultSize, taskInfoMetrics)
 
     // Runtime breakdown
     val runtimePcts = taskInfoMetrics.map { case (info, metrics) =>
@@ -95,17 +95,17 @@ private[spark] object StatsReportListener extends Logging {
 
   def extractDoubleDistribution(
     taskInfoMetrics: Seq[(TaskInfo, TaskMetrics)],
-    getMetric: (TaskInfo, TaskMetrics) => Option[Double]): 
Option[Distribution] = {
-    Distribution(taskInfoMetrics.flatMap { case (info, metric) => 
getMetric(info, metric) })
+    getMetric: (TaskInfo, TaskMetrics) => Double): Option[Distribution] = {
+    Distribution(taskInfoMetrics.map { case (info, metric) => getMetric(info, 
metric) })
   }
 
   // Is there some way to setup the types that I can get rid of this 
completely?
   def extractLongDistribution(
     taskInfoMetrics: Seq[(TaskInfo, TaskMetrics)],
-    getMetric: (TaskInfo, TaskMetrics) => Option[Long]): Option[Distribution] 
= {
+    getMetric: (TaskInfo, TaskMetrics) => Long): Option[Distribution] = {
     extractDoubleDistribution(
       taskInfoMetrics,
-      (info, metric) => { getMetric(info, metric).map(_.toDouble) })
+      (info, metric) => { getMetric(info, metric).toDouble })
   }
 
   def showDistribution(heading: String, d: Distribution, formatNumber: Double 
=> String) {
@@ -117,9 +117,9 @@ private[spark] object StatsReportListener extends Logging {
   }
 
   def showDistribution(
-    heading: String,
-    dOpt: Option[Distribution],
-    formatNumber: Double => String) {
+      heading: String,
+      dOpt: Option[Distribution],
+      formatNumber: Double => String) {
     dOpt.foreach { d => showDistribution(heading, d, formatNumber)}
   }
 
@@ -129,17 +129,17 @@ private[spark] object StatsReportListener extends Logging 
{
   }
 
   def showDistribution(
-    heading: String,
-    format: String,
-    getMetric: (TaskInfo, TaskMetrics) => Option[Double],
-    taskInfoMetrics: Seq[(TaskInfo, TaskMetrics)]) {
+      heading: String,
+      format: String,
+      getMetric: (TaskInfo, TaskMetrics) => Double,
+      taskInfoMetrics: Seq[(TaskInfo, TaskMetrics)]) {
     showDistribution(heading, extractDoubleDistribution(taskInfoMetrics, 
getMetric), format)
   }
 
   def showBytesDistribution(
-    heading: String,
-    getMetric: (TaskInfo, TaskMetrics) => Option[Long],
-    taskInfoMetrics: Seq[(TaskInfo, TaskMetrics)]) {
+      heading: String,
+      getMetric: (TaskInfo, TaskMetrics) => Long,
+      taskInfoMetrics: Seq[(TaskInfo, TaskMetrics)]) {
     showBytesDistribution(heading, extractLongDistribution(taskInfoMetrics, 
getMetric))
   }
 
@@ -157,9 +157,9 @@ private[spark] object StatsReportListener extends Logging {
   }
 
   def showMillisDistribution(
-    heading: String,
-    getMetric: (TaskInfo, TaskMetrics) => Option[Long],
-    taskInfoMetrics: Seq[(TaskInfo, TaskMetrics)]) {
+      heading: String,
+      getMetric: (TaskInfo, TaskMetrics) => Long,
+      taskInfoMetrics: Seq[(TaskInfo, TaskMetrics)]) {
     showMillisDistribution(heading, extractLongDistribution(taskInfoMetrics, 
getMetric))
   }
 
@@ -190,7 +190,7 @@ private case class RuntimePercentage(executorPct: Double, 
fetchPct: Option[Doubl
 private object RuntimePercentage {
   def apply(totalTime: Long, metrics: TaskMetrics): RuntimePercentage = {
     val denom = totalTime.toDouble
-    val fetchTime = metrics.shuffleReadMetrics.map(_.fetchWaitTime)
+    val fetchTime = Some(metrics.shuffleReadMetrics.fetchWaitTime)
     val fetch = fetchTime.map(_ / denom)
     val exec = (metrics.executorRunTime - fetchTime.getOrElse(0L)) / denom
     val other = 1.0 - (exec + fetch.getOrElse(0d))

http://git-wip-us.apache.org/repos/asf/spark/blob/8028a288/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
----------------------------------------------------------------------
diff --git 
a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala 
b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
index 876cdfa..5794f54 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
@@ -67,7 +67,7 @@ private[spark] class BlockStoreShuffleReader[K, C](
     }
 
     // Update the context task metrics for each record read.
-    val readMetrics = context.taskMetrics.registerTempShuffleReadMetrics()
+    val readMetrics = context.taskMetrics.createTempShuffleReadMetrics()
     val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]](
       recordIter.map { record =>
         readMetrics.incRecordsRead(1)

http://git-wip-us.apache.org/repos/asf/spark/blob/8028a288/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
----------------------------------------------------------------------
diff --git 
a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala 
b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
index 9276d95..6c4444f 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
@@ -41,7 +41,7 @@ private[spark] class HashShuffleWriter[K, V](
   // we don't try deleting files, etc twice.
   private var stopping = false
 
-  private val writeMetrics = metrics.registerShuffleWriteMetrics()
+  private val writeMetrics = metrics.shuffleWriteMetrics
 
   private val blockManager = SparkEnv.get.blockManager
   private val shuffle = shuffleBlockResolver.forMapTask(dep.shuffleId, mapId, 
numOutputSplits,

http://git-wip-us.apache.org/repos/asf/spark/blob/8028a288/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
----------------------------------------------------------------------
diff --git 
a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala 
b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
index 8ab1cee..1adacab 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
@@ -45,7 +45,7 @@ private[spark] class SortShuffleWriter[K, V, C](
 
   private var mapStatus: MapStatus = null
 
-  private val writeMetrics = 
context.taskMetrics().registerShuffleWriteMetrics()
+  private val writeMetrics = context.taskMetrics().shuffleWriteMetrics
 
   /** Write a bunch of records to this task's output */
   override def write(records: Iterator[Product2[K, V]]): Unit = {

http://git-wip-us.apache.org/repos/asf/spark/blob/8028a288/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala
----------------------------------------------------------------------
diff --git 
a/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala 
b/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala
index f8d6e9f..85452d6 100644
--- a/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala
+++ b/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala
@@ -170,7 +170,11 @@ private[v1] object AllStagesResource {
     val inputMetrics: Option[InputMetricDistributions] =
       new MetricHelper[InternalInputMetrics, 
InputMetricDistributions](rawMetrics, quantiles) {
         def getSubmetrics(raw: InternalTaskMetrics): 
Option[InternalInputMetrics] = {
-          raw.inputMetrics
+          if (raw.inputMetrics.isUpdated) {
+            Some(raw.inputMetrics)
+          } else {
+            None
+          }
         }
 
         def build: InputMetricDistributions = new InputMetricDistributions(
@@ -182,7 +186,11 @@ private[v1] object AllStagesResource {
     val outputMetrics: Option[OutputMetricDistributions] =
       new MetricHelper[InternalOutputMetrics, 
OutputMetricDistributions](rawMetrics, quantiles) {
         def getSubmetrics(raw: InternalTaskMetrics): 
Option[InternalOutputMetrics] = {
-          raw.outputMetrics
+          if (raw.outputMetrics.isUpdated) {
+            Some(raw.outputMetrics)
+          } else {
+            None
+          }
         }
         def build: OutputMetricDistributions = new OutputMetricDistributions(
           bytesWritten = submetricQuantiles(_.bytesWritten),
@@ -194,7 +202,11 @@ private[v1] object AllStagesResource {
       new MetricHelper[InternalShuffleReadMetrics, 
ShuffleReadMetricDistributions](rawMetrics,
         quantiles) {
         def getSubmetrics(raw: InternalTaskMetrics): 
Option[InternalShuffleReadMetrics] = {
-          raw.shuffleReadMetrics
+          if (raw.shuffleReadMetrics.isUpdated) {
+            Some(raw.shuffleReadMetrics)
+          } else {
+            None
+          }
         }
         def build: ShuffleReadMetricDistributions = new 
ShuffleReadMetricDistributions(
           readBytes = submetricQuantiles(_.totalBytesRead),
@@ -211,7 +223,11 @@ private[v1] object AllStagesResource {
       new MetricHelper[InternalShuffleWriteMetrics, 
ShuffleWriteMetricDistributions](rawMetrics,
         quantiles) {
         def getSubmetrics(raw: InternalTaskMetrics): 
Option[InternalShuffleWriteMetrics] = {
-          raw.shuffleWriteMetrics
+          if (raw.shuffleWriteMetrics.isUpdated) {
+            Some(raw.shuffleWriteMetrics)
+          } else {
+            None
+          }
         }
         def build: ShuffleWriteMetricDistributions = new 
ShuffleWriteMetricDistributions(
           writeBytes = submetricQuantiles(_.bytesWritten),
@@ -250,44 +266,62 @@ private[v1] object AllStagesResource {
       resultSerializationTime = internal.resultSerializationTime,
       memoryBytesSpilled = internal.memoryBytesSpilled,
       diskBytesSpilled = internal.diskBytesSpilled,
-      inputMetrics = internal.inputMetrics.map { convertInputMetrics },
-      outputMetrics = Option(internal.outputMetrics).flatten.map { 
convertOutputMetrics },
-      shuffleReadMetrics = internal.shuffleReadMetrics.map { 
convertShuffleReadMetrics },
-      shuffleWriteMetrics = internal.shuffleWriteMetrics.map { 
convertShuffleWriteMetrics }
+      inputMetrics = convertInputMetrics(internal.inputMetrics),
+      outputMetrics = convertOutputMetrics(internal.outputMetrics),
+      shuffleReadMetrics = 
convertShuffleReadMetrics(internal.shuffleReadMetrics),
+      shuffleWriteMetrics = 
convertShuffleWriteMetrics(internal.shuffleWriteMetrics)
     )
   }
 
-  def convertInputMetrics(internal: InternalInputMetrics): InputMetrics = {
-    new InputMetrics(
-      bytesRead = internal.bytesRead,
-      recordsRead = internal.recordsRead
-    )
+  def convertInputMetrics(internal: InternalInputMetrics): 
Option[InputMetrics] = {
+    if (internal.isUpdated) {
+      Some(new InputMetrics(
+        bytesRead = internal.bytesRead,
+        recordsRead = internal.recordsRead
+      ))
+    } else {
+      None
+    }
   }
 
-  def convertOutputMetrics(internal: InternalOutputMetrics): OutputMetrics = {
-    new OutputMetrics(
-      bytesWritten = internal.bytesWritten,
-      recordsWritten = internal.recordsWritten
-    )
+  def convertOutputMetrics(internal: InternalOutputMetrics): 
Option[OutputMetrics] = {
+    if (internal.isUpdated) {
+      Some(new OutputMetrics(
+        bytesWritten = internal.bytesWritten,
+        recordsWritten = internal.recordsWritten
+      ))
+    } else {
+      None
+    }
   }
 
-  def convertShuffleReadMetrics(internal: InternalShuffleReadMetrics): 
ShuffleReadMetrics = {
-    new ShuffleReadMetrics(
-      remoteBlocksFetched = internal.remoteBlocksFetched,
-      localBlocksFetched = internal.localBlocksFetched,
-      fetchWaitTime = internal.fetchWaitTime,
-      remoteBytesRead = internal.remoteBytesRead,
-      totalBlocksFetched = internal.totalBlocksFetched,
-      recordsRead = internal.recordsRead
-    )
+  def convertShuffleReadMetrics(
+      internal: InternalShuffleReadMetrics): Option[ShuffleReadMetrics] = {
+    if (internal.isUpdated) {
+      Some(new ShuffleReadMetrics(
+        remoteBlocksFetched = internal.remoteBlocksFetched,
+        localBlocksFetched = internal.localBlocksFetched,
+        fetchWaitTime = internal.fetchWaitTime,
+        remoteBytesRead = internal.remoteBytesRead,
+        totalBlocksFetched = internal.totalBlocksFetched,
+        recordsRead = internal.recordsRead
+      ))
+    } else {
+      None
+    }
   }
 
-  def convertShuffleWriteMetrics(internal: InternalShuffleWriteMetrics): 
ShuffleWriteMetrics = {
-    new ShuffleWriteMetrics(
-      bytesWritten = internal.bytesWritten,
-      writeTime = internal.writeTime,
-      recordsWritten = internal.recordsWritten
-    )
+  def convertShuffleWriteMetrics(
+      internal: InternalShuffleWriteMetrics): Option[ShuffleWriteMetrics] = {
+    if ((internal.bytesWritten | internal.writeTime | internal.recordsWritten) 
== 0) {
+      None
+    } else {
+      Some(new ShuffleWriteMetrics(
+        bytesWritten = internal.bytesWritten,
+        writeTime = internal.writeTime,
+        recordsWritten = internal.recordsWritten
+      ))
+    }
   }
 }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/8028a288/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
----------------------------------------------------------------------
diff --git 
a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
 
b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
index 4ec5b4b..4dc2f36 100644
--- 
a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
+++ 
b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
@@ -108,7 +108,7 @@ final class ShuffleBlockFetcherIterator(
   /** Current number of requests in flight */
   private[this] var reqsInFlight = 0
 
-  private[this] val shuffleMetrics = 
context.taskMetrics().registerTempShuffleReadMetrics()
+  private[this] val shuffleMetrics = 
context.taskMetrics().createTempShuffleReadMetrics()
 
   /**
    * Whether the iterator is still active. If isZombie is true, the callback 
interface will no

http://git-wip-us.apache.org/repos/asf/spark/blob/8028a288/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala 
b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala
index 3fd0efd..676f445 100644
--- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala
+++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala
@@ -119,26 +119,19 @@ class ExecutorsListener(storageStatusListener: 
StorageStatusListener, conf: Spar
       // Update shuffle read/write
       val metrics = taskEnd.taskMetrics
       if (metrics != null) {
-        metrics.inputMetrics.foreach { inputMetrics =>
-          executorToInputBytes(eid) =
-            executorToInputBytes.getOrElse(eid, 0L) + inputMetrics.bytesRead
-          executorToInputRecords(eid) =
-            executorToInputRecords.getOrElse(eid, 0L) + 
inputMetrics.recordsRead
-        }
-        metrics.outputMetrics.foreach { outputMetrics =>
-          executorToOutputBytes(eid) =
-            executorToOutputBytes.getOrElse(eid, 0L) + 
outputMetrics.bytesWritten
-          executorToOutputRecords(eid) =
-            executorToOutputRecords.getOrElse(eid, 0L) + 
outputMetrics.recordsWritten
-        }
-        metrics.shuffleReadMetrics.foreach { shuffleRead =>
-          executorToShuffleRead(eid) =
-            executorToShuffleRead.getOrElse(eid, 0L) + 
shuffleRead.remoteBytesRead
-        }
-        metrics.shuffleWriteMetrics.foreach { shuffleWrite =>
-          executorToShuffleWrite(eid) =
-            executorToShuffleWrite.getOrElse(eid, 0L) + 
shuffleWrite.bytesWritten
-        }
+        executorToInputBytes(eid) =
+          executorToInputBytes.getOrElse(eid, 0L) + 
metrics.inputMetrics.bytesRead
+        executorToInputRecords(eid) =
+          executorToInputRecords.getOrElse(eid, 0L) + 
metrics.inputMetrics.recordsRead
+        executorToOutputBytes(eid) =
+          executorToOutputBytes.getOrElse(eid, 0L) + 
metrics.outputMetrics.bytesWritten
+        executorToOutputRecords(eid) =
+          executorToOutputRecords.getOrElse(eid, 0L) + 
metrics.outputMetrics.recordsWritten
+
+        executorToShuffleRead(eid) =
+          executorToShuffleRead.getOrElse(eid, 0L) + 
metrics.shuffleReadMetrics.remoteBytesRead
+        executorToShuffleWrite(eid) =
+          executorToShuffleWrite.getOrElse(eid, 0L) + 
metrics.shuffleWriteMetrics.bytesWritten
         executorToJvmGCTime(eid) = executorToJvmGCTime.getOrElse(eid, 0L) + 
metrics.jvmGCTime
       }
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/8028a288/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala
----------------------------------------------------------------------
diff --git 
a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala 
b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala
index 13f5f84..9e4771c 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala
@@ -434,50 +434,50 @@ class JobProgressListener(conf: SparkConf) extends 
SparkListener with Logging {
     val execSummary = stageData.executorSummary.getOrElseUpdate(execId, new 
ExecutorSummary)
 
     val shuffleWriteDelta =
-      (taskMetrics.shuffleWriteMetrics.map(_.bytesWritten).getOrElse(0L)
-      - 
oldMetrics.flatMap(_.shuffleWriteMetrics).map(_.bytesWritten).getOrElse(0L))
+      taskMetrics.shuffleWriteMetrics.bytesWritten -
+        oldMetrics.map(_.shuffleWriteMetrics.bytesWritten).getOrElse(0L)
     stageData.shuffleWriteBytes += shuffleWriteDelta
     execSummary.shuffleWrite += shuffleWriteDelta
 
     val shuffleWriteRecordsDelta =
-      (taskMetrics.shuffleWriteMetrics.map(_.recordsWritten).getOrElse(0L)
-      - 
oldMetrics.flatMap(_.shuffleWriteMetrics).map(_.recordsWritten).getOrElse(0L))
+      taskMetrics.shuffleWriteMetrics.recordsWritten -
+        oldMetrics.map(_.shuffleWriteMetrics.recordsWritten).getOrElse(0L)
     stageData.shuffleWriteRecords += shuffleWriteRecordsDelta
     execSummary.shuffleWriteRecords += shuffleWriteRecordsDelta
 
     val shuffleReadDelta =
-      (taskMetrics.shuffleReadMetrics.map(_.totalBytesRead).getOrElse(0L)
-        - 
oldMetrics.flatMap(_.shuffleReadMetrics).map(_.totalBytesRead).getOrElse(0L))
+      taskMetrics.shuffleReadMetrics.totalBytesRead -
+        oldMetrics.map(_.shuffleReadMetrics.totalBytesRead).getOrElse(0L)
     stageData.shuffleReadTotalBytes += shuffleReadDelta
     execSummary.shuffleRead += shuffleReadDelta
 
     val shuffleReadRecordsDelta =
-      (taskMetrics.shuffleReadMetrics.map(_.recordsRead).getOrElse(0L)
-      - 
oldMetrics.flatMap(_.shuffleReadMetrics).map(_.recordsRead).getOrElse(0L))
+      taskMetrics.shuffleReadMetrics.recordsRead -
+        oldMetrics.map(_.shuffleReadMetrics.recordsRead).getOrElse(0L)
     stageData.shuffleReadRecords += shuffleReadRecordsDelta
     execSummary.shuffleReadRecords += shuffleReadRecordsDelta
 
     val inputBytesDelta =
-      (taskMetrics.inputMetrics.map(_.bytesRead).getOrElse(0L)
-      - oldMetrics.flatMap(_.inputMetrics).map(_.bytesRead).getOrElse(0L))
+      taskMetrics.inputMetrics.bytesRead -
+        oldMetrics.map(_.inputMetrics.bytesRead).getOrElse(0L)
     stageData.inputBytes += inputBytesDelta
     execSummary.inputBytes += inputBytesDelta
 
     val inputRecordsDelta =
-      (taskMetrics.inputMetrics.map(_.recordsRead).getOrElse(0L)
-      - oldMetrics.flatMap(_.inputMetrics).map(_.recordsRead).getOrElse(0L))
+      taskMetrics.inputMetrics.recordsRead -
+        oldMetrics.map(_.inputMetrics.recordsRead).getOrElse(0L)
     stageData.inputRecords += inputRecordsDelta
     execSummary.inputRecords += inputRecordsDelta
 
     val outputBytesDelta =
-      (taskMetrics.outputMetrics.map(_.bytesWritten).getOrElse(0L)
-        - 
oldMetrics.flatMap(_.outputMetrics).map(_.bytesWritten).getOrElse(0L))
+      taskMetrics.outputMetrics.bytesWritten -
+        oldMetrics.map(_.outputMetrics.bytesWritten).getOrElse(0L)
     stageData.outputBytes += outputBytesDelta
     execSummary.outputBytes += outputBytesDelta
 
     val outputRecordsDelta =
-      (taskMetrics.outputMetrics.map(_.recordsWritten).getOrElse(0L)
-        - 
oldMetrics.flatMap(_.outputMetrics).map(_.recordsWritten).getOrElse(0L))
+      taskMetrics.outputMetrics.recordsWritten -
+        oldMetrics.map(_.outputMetrics.recordsWritten).getOrElse(0L)
     stageData.outputRecords += outputRecordsDelta
     execSummary.outputRecords += outputRecordsDelta
 

http://git-wip-us.apache.org/repos/asf/spark/blob/8028a288/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala 
b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
index 8a44bbd..5d1928a 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
@@ -428,29 +428,29 @@ private[ui] class StagePage(parent: StagesTab) extends 
WebUIPage("stage") {
           }
 
           val inputSizes = validTasks.map { taskUIData: TaskUIData =>
-            
taskUIData.metrics.get.inputMetrics.map(_.bytesRead).getOrElse(0L).toDouble
+            taskUIData.metrics.get.inputMetrics.bytesRead.toDouble
           }
 
           val inputRecords = validTasks.map { taskUIData: TaskUIData =>
-            
taskUIData.metrics.get.inputMetrics.map(_.recordsRead).getOrElse(0L).toDouble
+            taskUIData.metrics.get.inputMetrics.recordsRead.toDouble
           }
 
           val inputQuantiles = <td>Input Size / Records</td> +:
             getFormattedSizeQuantilesWithRecords(inputSizes, inputRecords)
 
           val outputSizes = validTasks.map { taskUIData: TaskUIData =>
-            
taskUIData.metrics.get.outputMetrics.map(_.bytesWritten).getOrElse(0L).toDouble
+            taskUIData.metrics.get.outputMetrics.bytesWritten.toDouble
           }
 
           val outputRecords = validTasks.map { taskUIData: TaskUIData =>
-            
taskUIData.metrics.get.outputMetrics.map(_.recordsWritten).getOrElse(0L).toDouble
+            taskUIData.metrics.get.outputMetrics.recordsWritten.toDouble
           }
 
           val outputQuantiles = <td>Output Size / Records</td> +:
             getFormattedSizeQuantilesWithRecords(outputSizes, outputRecords)
 
           val shuffleReadBlockedTimes = validTasks.map { taskUIData: 
TaskUIData =>
-            
taskUIData.metrics.get.shuffleReadMetrics.map(_.fetchWaitTime).getOrElse(0L).toDouble
+            taskUIData.metrics.get.shuffleReadMetrics.fetchWaitTime.toDouble
           }
           val shuffleReadBlockedQuantiles =
             <td>
@@ -462,10 +462,10 @@ private[ui] class StagePage(parent: StagesTab) extends 
WebUIPage("stage") {
             getFormattedTimeQuantiles(shuffleReadBlockedTimes)
 
           val shuffleReadTotalSizes = validTasks.map { taskUIData: TaskUIData 
=>
-            
taskUIData.metrics.get.shuffleReadMetrics.map(_.totalBytesRead).getOrElse(0L).toDouble
+            taskUIData.metrics.get.shuffleReadMetrics.totalBytesRead.toDouble
           }
           val shuffleReadTotalRecords = validTasks.map { taskUIData: 
TaskUIData =>
-            
taskUIData.metrics.get.shuffleReadMetrics.map(_.recordsRead).getOrElse(0L).toDouble
+            taskUIData.metrics.get.shuffleReadMetrics.recordsRead.toDouble
           }
           val shuffleReadTotalQuantiles =
             <td>
@@ -477,7 +477,7 @@ private[ui] class StagePage(parent: StagesTab) extends 
WebUIPage("stage") {
             getFormattedSizeQuantilesWithRecords(shuffleReadTotalSizes, 
shuffleReadTotalRecords)
 
           val shuffleReadRemoteSizes = validTasks.map { taskUIData: TaskUIData 
=>
-            
taskUIData.metrics.get.shuffleReadMetrics.map(_.remoteBytesRead).getOrElse(0L).toDouble
+            taskUIData.metrics.get.shuffleReadMetrics.remoteBytesRead.toDouble
           }
           val shuffleReadRemoteQuantiles =
             <td>
@@ -489,11 +489,11 @@ private[ui] class StagePage(parent: StagesTab) extends 
WebUIPage("stage") {
             getFormattedSizeQuantiles(shuffleReadRemoteSizes)
 
           val shuffleWriteSizes = validTasks.map { taskUIData: TaskUIData =>
-            
taskUIData.metrics.get.shuffleWriteMetrics.map(_.bytesWritten).getOrElse(0L).toDouble
+            taskUIData.metrics.get.shuffleWriteMetrics.bytesWritten.toDouble
           }
 
           val shuffleWriteRecords = validTasks.map { taskUIData: TaskUIData =>
-            
taskUIData.metrics.get.shuffleWriteMetrics.map(_.recordsWritten).getOrElse(0L).toDouble
+            taskUIData.metrics.get.shuffleWriteMetrics.recordsWritten.toDouble
           }
 
           val shuffleWriteQuantiles = <td>Shuffle Write Size / Records</td> +:
@@ -603,11 +603,10 @@ private[ui] class StagePage(parent: StagesTab) extends 
WebUIPage("stage") {
 
         val metricsOpt = taskUIData.metrics
         val shuffleReadTime =
-          
metricsOpt.flatMap(_.shuffleReadMetrics.map(_.fetchWaitTime)).getOrElse(0L)
+          metricsOpt.map(_.shuffleReadMetrics.fetchWaitTime).getOrElse(0L)
         val shuffleReadTimeProportion = toProportion(shuffleReadTime)
         val shuffleWriteTime =
-          (metricsOpt.flatMap(_.shuffleWriteMetrics
-            .map(_.writeTime)).getOrElse(0L) / 1e6).toLong
+          (metricsOpt.map(_.shuffleWriteMetrics.writeTime).getOrElse(0L) / 
1e6).toLong
         val shuffleWriteTimeProportion = toProportion(shuffleWriteTime)
 
         val serializationTime = 
metricsOpt.map(_.resultSerializationTime).getOrElse(0L)
@@ -890,21 +889,21 @@ private[ui] class TaskDataSource(
       }
     val peakExecutionMemoryUsed = 
metrics.map(_.peakExecutionMemory).getOrElse(0L)
 
-    val maybeInput = metrics.flatMap(_.inputMetrics)
+    val maybeInput = metrics.map(_.inputMetrics)
     val inputSortable = maybeInput.map(_.bytesRead).getOrElse(0L)
     val inputReadable = maybeInput
-      .map(m => s"${Utils.bytesToString(m.bytesRead)} 
(${m.readMethod.toString.toLowerCase()})")
+      .map(m => s"${Utils.bytesToString(m.bytesRead)}")
       .getOrElse("")
     val inputRecords = maybeInput.map(_.recordsRead.toString).getOrElse("")
 
-    val maybeOutput = metrics.flatMap(_.outputMetrics)
+    val maybeOutput = metrics.map(_.outputMetrics)
     val outputSortable = maybeOutput.map(_.bytesWritten).getOrElse(0L)
     val outputReadable = maybeOutput
       .map(m => s"${Utils.bytesToString(m.bytesWritten)}")
       .getOrElse("")
     val outputRecords = 
maybeOutput.map(_.recordsWritten.toString).getOrElse("")
 
-    val maybeShuffleRead = metrics.flatMap(_.shuffleReadMetrics)
+    val maybeShuffleRead = metrics.map(_.shuffleReadMetrics)
     val shuffleReadBlockedTimeSortable = 
maybeShuffleRead.map(_.fetchWaitTime).getOrElse(0L)
     val shuffleReadBlockedTimeReadable =
       maybeShuffleRead.map(ms => 
UIUtils.formatDuration(ms.fetchWaitTime)).getOrElse("")
@@ -918,14 +917,14 @@ private[ui] class TaskDataSource(
     val shuffleReadRemoteSortable = remoteShuffleBytes.getOrElse(0L)
     val shuffleReadRemoteReadable = 
remoteShuffleBytes.map(Utils.bytesToString).getOrElse("")
 
-    val maybeShuffleWrite = metrics.flatMap(_.shuffleWriteMetrics)
+    val maybeShuffleWrite = metrics.map(_.shuffleWriteMetrics)
     val shuffleWriteSortable = 
maybeShuffleWrite.map(_.bytesWritten).getOrElse(0L)
     val shuffleWriteReadable = maybeShuffleWrite
       .map(m => s"${Utils.bytesToString(m.bytesWritten)}").getOrElse("")
     val shuffleWriteRecords = maybeShuffleWrite
       .map(_.recordsWritten.toString).getOrElse("")
 
-    val maybeWriteTime = 
metrics.flatMap(_.shuffleWriteMetrics).map(_.writeTime)
+    val maybeWriteTime = metrics.map(_.shuffleWriteMetrics.writeTime)
     val writeTimeSortable = maybeWriteTime.getOrElse(0L)
     val writeTimeReadable = maybeWriteTime.map(t => t / (1000 * 1000)).map { 
ms =>
       if (ms == 0) "" else UIUtils.formatDuration(ms)

http://git-wip-us.apache.org/repos/asf/spark/blob/8028a288/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala 
b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
index 558767e..17b33c7 100644
--- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
+++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
@@ -326,33 +326,35 @@ private[spark] object JsonProtocol {
   }
 
   def taskMetricsToJson(taskMetrics: TaskMetrics): JValue = {
-    val shuffleReadMetrics: JValue =
-      taskMetrics.shuffleReadMetrics.map { rm =>
-        ("Remote Blocks Fetched" -> rm.remoteBlocksFetched) ~
-        ("Local Blocks Fetched" -> rm.localBlocksFetched) ~
-        ("Fetch Wait Time" -> rm.fetchWaitTime) ~
-        ("Remote Bytes Read" -> rm.remoteBytesRead) ~
-        ("Local Bytes Read" -> rm.localBytesRead) ~
-        ("Total Records Read" -> rm.recordsRead)
-      }.getOrElse(JNothing)
-    val shuffleWriteMetrics: JValue =
-      taskMetrics.shuffleWriteMetrics.map { wm =>
-        ("Shuffle Bytes Written" -> wm.bytesWritten) ~
-        ("Shuffle Write Time" -> wm.writeTime) ~
-        ("Shuffle Records Written" -> wm.recordsWritten)
-      }.getOrElse(JNothing)
-    val inputMetrics: JValue =
-      taskMetrics.inputMetrics.map { im =>
-        ("Data Read Method" -> im.readMethod.toString) ~
-        ("Bytes Read" -> im.bytesRead) ~
-        ("Records Read" -> im.recordsRead)
-      }.getOrElse(JNothing)
-    val outputMetrics: JValue =
-      taskMetrics.outputMetrics.map { om =>
-        ("Data Write Method" -> om.writeMethod.toString) ~
-        ("Bytes Written" -> om.bytesWritten) ~
-        ("Records Written" -> om.recordsWritten)
-      }.getOrElse(JNothing)
+    val shuffleReadMetrics: JValue = if 
(taskMetrics.shuffleReadMetrics.isUpdated) {
+      ("Remote Blocks Fetched" -> 
taskMetrics.shuffleReadMetrics.remoteBlocksFetched) ~
+        ("Local Blocks Fetched" -> 
taskMetrics.shuffleReadMetrics.localBlocksFetched) ~
+        ("Fetch Wait Time" -> taskMetrics.shuffleReadMetrics.fetchWaitTime) ~
+        ("Remote Bytes Read" -> 
taskMetrics.shuffleReadMetrics.remoteBytesRead) ~
+        ("Local Bytes Read" -> taskMetrics.shuffleReadMetrics.localBytesRead) ~
+        ("Total Records Read" -> taskMetrics.shuffleReadMetrics.recordsRead)
+    } else {
+      JNothing
+    }
+    val shuffleWriteMetrics: JValue = if 
(taskMetrics.shuffleWriteMetrics.isUpdated) {
+      ("Shuffle Bytes Written" -> 
taskMetrics.shuffleWriteMetrics.bytesWritten) ~
+        ("Shuffle Write Time" -> taskMetrics.shuffleWriteMetrics.writeTime) ~
+        ("Shuffle Records Written" -> 
taskMetrics.shuffleWriteMetrics.recordsWritten)
+    } else {
+      JNothing
+    }
+    val inputMetrics: JValue = if (taskMetrics.inputMetrics.isUpdated) {
+      ("Bytes Read" -> taskMetrics.inputMetrics.bytesRead) ~
+        ("Records Read" -> taskMetrics.inputMetrics.recordsRead)
+    } else {
+      JNothing
+    }
+    val outputMetrics: JValue = if (taskMetrics.outputMetrics.isUpdated) {
+      ("Bytes Written" -> taskMetrics.outputMetrics.bytesWritten) ~
+        ("Records Written" -> taskMetrics.outputMetrics.recordsWritten)
+    } else {
+      JNothing
+    }
     val updatedBlocks =
       JArray(taskMetrics.updatedBlockStatuses.toList.map { case (id, status) =>
         ("Block ID" -> id.toString) ~
@@ -781,7 +783,7 @@ private[spark] object JsonProtocol {
 
     // Shuffle read metrics
     Utils.jsonOption(json \ "Shuffle Read Metrics").foreach { readJson =>
-      val readMetrics = metrics.registerTempShuffleReadMetrics()
+      val readMetrics = metrics.createTempShuffleReadMetrics()
       readMetrics.incRemoteBlocksFetched((readJson \ "Remote Blocks 
Fetched").extract[Int])
       readMetrics.incLocalBlocksFetched((readJson \ "Local Blocks 
Fetched").extract[Int])
       readMetrics.incRemoteBytesRead((readJson \ "Remote Bytes 
Read").extract[Long])
@@ -794,7 +796,7 @@ private[spark] object JsonProtocol {
     // Shuffle write metrics
     // TODO: Drop the redundant "Shuffle" since it's inconsistent with related 
classes.
     Utils.jsonOption(json \ "Shuffle Write Metrics").foreach { writeJson =>
-      val writeMetrics = metrics.registerShuffleWriteMetrics()
+      val writeMetrics = metrics.shuffleWriteMetrics
       writeMetrics.incBytesWritten((writeJson \ "Shuffle Bytes 
Written").extract[Long])
       writeMetrics.incRecordsWritten((writeJson \ "Shuffle Records Written")
         .extractOpt[Long].getOrElse(0L))
@@ -803,16 +805,14 @@ private[spark] object JsonProtocol {
 
     // Output metrics
     Utils.jsonOption(json \ "Output Metrics").foreach { outJson =>
-      val writeMethod = DataWriteMethod.withName((outJson \ "Data Write 
Method").extract[String])
-      val outputMetrics = metrics.registerOutputMetrics(writeMethod)
+      val outputMetrics = metrics.outputMetrics
       outputMetrics.setBytesWritten((outJson \ "Bytes Written").extract[Long])
       outputMetrics.setRecordsWritten((outJson \ "Records 
Written").extractOpt[Long].getOrElse(0L))
     }
 
     // Input metrics
     Utils.jsonOption(json \ "Input Metrics").foreach { inJson =>
-      val readMethod = DataReadMethod.withName((inJson \ "Data Read 
Method").extract[String])
-      val inputMetrics = metrics.registerInputMetrics(readMethod)
+      val inputMetrics = metrics.inputMetrics
       inputMetrics.incBytesRead((inJson \ "Bytes Read").extract[Long])
       inputMetrics.incRecordsRead((inJson \ "Records 
Read").extractOpt[Long].getOrElse(0L))
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/8028a288/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
index 561ba22..916053f 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
@@ -645,7 +645,7 @@ private[spark] class ExternalSorter[K, V, C](
       blockId: BlockId,
       outputFile: File): Array[Long] = {
 
-    val writeMetrics = context.taskMetrics().registerShuffleWriteMetrics()
+    val writeMetrics = context.taskMetrics().shuffleWriteMetrics
 
     // Track location of each range in the output file
     val lengths = new Array[Long](numPartitions)

http://git-wip-us.apache.org/repos/asf/spark/blob/8028a288/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
index 30750b1..fbaaa1c 100644
--- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
+++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
@@ -249,8 +249,8 @@ public class UnsafeShuffleWriterSuite {
     assertTrue(mapStatus.isDefined());
     assertTrue(mergedOutputFile.exists());
     assertArrayEquals(new long[NUM_PARTITITONS], partitionSizesInMergedFile);
-    assertEquals(0, taskMetrics.shuffleWriteMetrics().get().recordsWritten());
-    assertEquals(0, taskMetrics.shuffleWriteMetrics().get().bytesWritten());
+    assertEquals(0, taskMetrics.shuffleWriteMetrics().recordsWritten());
+    assertEquals(0, taskMetrics.shuffleWriteMetrics().bytesWritten());
     assertEquals(0, taskMetrics.diskBytesSpilled());
     assertEquals(0, taskMetrics.memoryBytesSpilled());
   }
@@ -279,7 +279,7 @@ public class UnsafeShuffleWriterSuite {
       HashMultiset.create(dataToWrite),
       HashMultiset.create(readRecordsFromFile()));
     assertSpillFilesWereCleanedUp();
-    ShuffleWriteMetrics shuffleWriteMetrics = 
taskMetrics.shuffleWriteMetrics().get();
+    ShuffleWriteMetrics shuffleWriteMetrics = 
taskMetrics.shuffleWriteMetrics();
     assertEquals(dataToWrite.size(), shuffleWriteMetrics.recordsWritten());
     assertEquals(0, taskMetrics.diskBytesSpilled());
     assertEquals(0, taskMetrics.memoryBytesSpilled());
@@ -321,7 +321,7 @@ public class UnsafeShuffleWriterSuite {
 
     assertEquals(HashMultiset.create(dataToWrite), 
HashMultiset.create(readRecordsFromFile()));
     assertSpillFilesWereCleanedUp();
-    ShuffleWriteMetrics shuffleWriteMetrics = 
taskMetrics.shuffleWriteMetrics().get();
+    ShuffleWriteMetrics shuffleWriteMetrics = 
taskMetrics.shuffleWriteMetrics();
     assertEquals(dataToWrite.size(), shuffleWriteMetrics.recordsWritten());
     assertThat(taskMetrics.diskBytesSpilled(), greaterThan(0L));
     assertThat(taskMetrics.diskBytesSpilled(), 
lessThan(mergedOutputFile.length()));
@@ -383,7 +383,7 @@ public class UnsafeShuffleWriterSuite {
     writer.stop(true);
     readRecordsFromFile();
     assertSpillFilesWereCleanedUp();
-    ShuffleWriteMetrics shuffleWriteMetrics = 
taskMetrics.shuffleWriteMetrics().get();
+    ShuffleWriteMetrics shuffleWriteMetrics = 
taskMetrics.shuffleWriteMetrics();
     assertEquals(dataToWrite.size(), shuffleWriteMetrics.recordsWritten());
     assertThat(taskMetrics.diskBytesSpilled(), greaterThan(0L));
     assertThat(taskMetrics.diskBytesSpilled(), 
lessThan(mergedOutputFile.length()));
@@ -404,7 +404,7 @@ public class UnsafeShuffleWriterSuite {
     writer.stop(true);
     readRecordsFromFile();
     assertSpillFilesWereCleanedUp();
-    ShuffleWriteMetrics shuffleWriteMetrics = 
taskMetrics.shuffleWriteMetrics().get();
+    ShuffleWriteMetrics shuffleWriteMetrics = 
taskMetrics.shuffleWriteMetrics();
     assertEquals(dataToWrite.size(), shuffleWriteMetrics.recordsWritten());
     assertThat(taskMetrics.diskBytesSpilled(), greaterThan(0L));
     assertThat(taskMetrics.diskBytesSpilled(), 
lessThan(mergedOutputFile.length()));

http://git-wip-us.apache.org/repos/asf/spark/blob/8028a288/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala b/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala
index 4745506..db087a9 100644
--- a/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala
@@ -59,11 +59,9 @@ class InternalAccumulatorSuite extends SparkFunSuite with LocalSparkContext {
     assert(getParam(shuffleWrite.RECORDS_WRITTEN) === LongAccumulatorParam)
     assert(getParam(shuffleWrite.WRITE_TIME) === LongAccumulatorParam)
     // input
-    assert(getParam(input.READ_METHOD) === StringAccumulatorParam)
     assert(getParam(input.RECORDS_READ) === LongAccumulatorParam)
     assert(getParam(input.BYTES_READ) === LongAccumulatorParam)
     // output
-    assert(getParam(output.WRITE_METHOD) === StringAccumulatorParam)
     assert(getParam(output.RECORDS_WRITTEN) === LongAccumulatorParam)
     assert(getParam(output.BYTES_WRITTEN) === LongAccumulatorParam)
     // default to Long
@@ -77,18 +75,15 @@ class InternalAccumulatorSuite extends SparkFunSuite with LocalSparkContext {
     val executorRunTime = create(EXECUTOR_RUN_TIME)
     val updatedBlockStatuses = create(UPDATED_BLOCK_STATUSES)
     val shuffleRemoteBlocksRead = create(shuffleRead.REMOTE_BLOCKS_FETCHED)
-    val inputReadMethod = create(input.READ_METHOD)
     assert(executorRunTime.name === Some(EXECUTOR_RUN_TIME))
     assert(updatedBlockStatuses.name === Some(UPDATED_BLOCK_STATUSES))
     assert(shuffleRemoteBlocksRead.name === 
Some(shuffleRead.REMOTE_BLOCKS_FETCHED))
-    assert(inputReadMethod.name === Some(input.READ_METHOD))
     assert(executorRunTime.value.isInstanceOf[Long])
     assert(updatedBlockStatuses.value.isInstanceOf[Seq[_]])
     // We cannot assert the type of the value directly since the type 
parameter is erased.
     // Instead, try casting a `Seq` of expected type and see if it fails in 
run time.
     updatedBlockStatuses.setValueAny(Seq.empty[(BlockId, BlockStatus)])
     assert(shuffleRemoteBlocksRead.value.isInstanceOf[Int])
-    assert(inputReadMethod.value.isInstanceOf[String])
     // default to Long
     val anything = create(METRICS_PREFIX + "anything")
     assert(anything.value.isInstanceOf[Long])

http://git-wip-us.apache.org/repos/asf/spark/blob/8028a288/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
index cd7d2e1..079109d 100644
--- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
@@ -450,14 +450,10 @@ object ShuffleSuite {
     @volatile var bytesRead: Long = 0
     val listener = new SparkListener {
       override def onTaskEnd(taskEnd: SparkListenerTaskEnd) {
-        taskEnd.taskMetrics.shuffleWriteMetrics.foreach { m =>
-          recordsWritten += m.recordsWritten
-          bytesWritten += m.bytesWritten
-        }
-        taskEnd.taskMetrics.shuffleReadMetrics.foreach { m =>
-          recordsRead += m.recordsRead
-          bytesRead += m.totalBytesRead
-        }
+        recordsWritten += 
taskEnd.taskMetrics.shuffleWriteMetrics.recordsWritten
+        bytesWritten += taskEnd.taskMetrics.shuffleWriteMetrics.bytesWritten
+        recordsRead += taskEnd.taskMetrics.shuffleReadMetrics.recordsRead
+        bytesRead += taskEnd.taskMetrics.shuffleReadMetrics.totalBytesRead
       }
     }
     sc.addSparkListener(listener)

http://git-wip-us.apache.org/repos/asf/spark/blob/8028a288/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala b/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala
index d91f50f..a263fce 100644
--- a/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala
@@ -30,24 +30,6 @@ class TaskMetricsSuite extends SparkFunSuite {
   import StorageLevel._
   import TaskMetricsSuite._
 
-  test("create") {
-    val internalAccums = InternalAccumulator.createAll()
-    val tm1 = new TaskMetrics
-    val tm2 = new TaskMetrics(internalAccums)
-    assert(tm1.accumulatorUpdates().size === internalAccums.size)
-    assert(tm1.shuffleReadMetrics.isEmpty)
-    assert(tm1.shuffleWriteMetrics.isEmpty)
-    assert(tm1.inputMetrics.isEmpty)
-    assert(tm1.outputMetrics.isEmpty)
-    assert(tm2.accumulatorUpdates().size === internalAccums.size)
-    assert(tm2.shuffleReadMetrics.isEmpty)
-    assert(tm2.shuffleWriteMetrics.isEmpty)
-    assert(tm2.inputMetrics.isEmpty)
-    assert(tm2.outputMetrics.isEmpty)
-    // TaskMetrics constructor expects minimal set of initial accumulators
-    intercept[IllegalArgumentException] { new 
TaskMetrics(Seq.empty[Accumulator[_]]) }
-  }
-
   test("create with unnamed accum") {
     intercept[IllegalArgumentException] {
       new TaskMetrics(
@@ -110,11 +92,9 @@ class TaskMetricsSuite extends SparkFunSuite {
       .map { a => (a.name.get, a) }.toMap[String, Accumulator[_]]
     accums(BYTES_READ).setValueAny(1L)
     accums(RECORDS_READ).setValueAny(2L)
-    accums(READ_METHOD).setValueAny(DataReadMethod.Hadoop.toString)
     val im = new InputMetrics(accums)
     assert(im.bytesRead === 1L)
     assert(im.recordsRead === 2L)
-    assert(im.readMethod === DataReadMethod.Hadoop)
   }
 
   test("create output metrics") {
@@ -123,11 +103,9 @@ class TaskMetricsSuite extends SparkFunSuite {
       .map { a => (a.name.get, a) }.toMap[String, Accumulator[_]]
     accums(BYTES_WRITTEN).setValueAny(1L)
     accums(RECORDS_WRITTEN).setValueAny(2L)
-    accums(WRITE_METHOD).setValueAny(DataWriteMethod.Hadoop.toString)
     val om = new OutputMetrics(accums)
     assert(om.bytesWritten === 1L)
     assert(om.recordsWritten === 2L)
-    assert(om.writeMethod === DataWriteMethod.Hadoop)
   }
 
   test("mutating values") {
@@ -183,14 +161,12 @@ class TaskMetricsSuite extends SparkFunSuite {
     val accums = InternalAccumulator.createAll()
     val tm = new TaskMetrics(accums)
     def assertValEquals[T](tmValue: ShuffleReadMetrics => T, name: String, 
value: T): Unit = {
-      assertValueEquals(tm, tm => tmValue(tm.shuffleReadMetrics.get), accums, 
name, value)
+      assertValueEquals(tm, tm => tmValue(tm.shuffleReadMetrics), accums, 
name, value)
     }
     // create shuffle read metrics
-    assert(tm.shuffleReadMetrics.isEmpty)
-    tm.registerTempShuffleReadMetrics()
+    tm.createTempShuffleReadMetrics()
     tm.mergeShuffleReadMetrics()
-    assert(tm.shuffleReadMetrics.isDefined)
-    val sr = tm.shuffleReadMetrics.get
+    val sr = tm.shuffleReadMetrics
     // initial values
     assertValEquals(_.remoteBlocksFetched, REMOTE_BLOCKS_FETCHED, 0)
     assertValEquals(_.localBlocksFetched, LOCAL_BLOCKS_FETCHED, 0)
@@ -237,13 +213,10 @@ class TaskMetricsSuite extends SparkFunSuite {
     val accums = InternalAccumulator.createAll()
     val tm = new TaskMetrics(accums)
     def assertValEquals[T](tmValue: ShuffleWriteMetrics => T, name: String, 
value: T): Unit = {
-      assertValueEquals(tm, tm => tmValue(tm.shuffleWriteMetrics.get), accums, 
name, value)
+      assertValueEquals(tm, tm => tmValue(tm.shuffleWriteMetrics), accums, 
name, value)
     }
     // create shuffle write metrics
-    assert(tm.shuffleWriteMetrics.isEmpty)
-    tm.registerShuffleWriteMetrics()
-    assert(tm.shuffleWriteMetrics.isDefined)
-    val sw = tm.shuffleWriteMetrics.get
+    val sw = tm.shuffleWriteMetrics
     // initial values
     assertValEquals(_.bytesWritten, BYTES_WRITTEN, 0L)
     assertValEquals(_.recordsWritten, RECORDS_WRITTEN, 0L)
@@ -270,28 +243,22 @@ class TaskMetricsSuite extends SparkFunSuite {
     val accums = InternalAccumulator.createAll()
     val tm = new TaskMetrics(accums)
     def assertValEquals(tmValue: InputMetrics => Any, name: String, value: 
Any): Unit = {
-      assertValueEquals(tm, tm => tmValue(tm.inputMetrics.get), accums, name, 
value,
+      assertValueEquals(tm, tm => tmValue(tm.inputMetrics), accums, name, 
value,
         (x: Any, y: Any) => assert(x.toString === y.toString))
     }
     // create input metrics
-    assert(tm.inputMetrics.isEmpty)
-    tm.registerInputMetrics(DataReadMethod.Memory)
-    assert(tm.inputMetrics.isDefined)
-    val in = tm.inputMetrics.get
+    val in = tm.inputMetrics
     // initial values
     assertValEquals(_.bytesRead, BYTES_READ, 0L)
     assertValEquals(_.recordsRead, RECORDS_READ, 0L)
-    assertValEquals(_.readMethod, READ_METHOD, DataReadMethod.Memory)
     // set and increment values
     in.setBytesRead(1L)
     in.setBytesRead(2L)
     in.incRecordsRead(1L)
     in.incRecordsRead(2L)
-    in.setReadMethod(DataReadMethod.Disk)
     // assert new values exist
     assertValEquals(_.bytesRead, BYTES_READ, 2L)
     assertValEquals(_.recordsRead, RECORDS_READ, 3L)
-    assertValEquals(_.readMethod, READ_METHOD, DataReadMethod.Disk)
   }
 
   test("mutating output metrics values") {
@@ -299,85 +266,42 @@ class TaskMetricsSuite extends SparkFunSuite {
     val accums = InternalAccumulator.createAll()
     val tm = new TaskMetrics(accums)
     def assertValEquals(tmValue: OutputMetrics => Any, name: String, value: 
Any): Unit = {
-      assertValueEquals(tm, tm => tmValue(tm.outputMetrics.get), accums, name, 
value,
+      assertValueEquals(tm, tm => tmValue(tm.outputMetrics), accums, name, 
value,
         (x: Any, y: Any) => assert(x.toString === y.toString))
     }
     // create input metrics
-    assert(tm.outputMetrics.isEmpty)
-    tm.registerOutputMetrics(DataWriteMethod.Hadoop)
-    assert(tm.outputMetrics.isDefined)
-    val out = tm.outputMetrics.get
+    val out = tm.outputMetrics
     // initial values
     assertValEquals(_.bytesWritten, BYTES_WRITTEN, 0L)
     assertValEquals(_.recordsWritten, RECORDS_WRITTEN, 0L)
-    assertValEquals(_.writeMethod, WRITE_METHOD, DataWriteMethod.Hadoop)
     // set values
     out.setBytesWritten(1L)
     out.setBytesWritten(2L)
     out.setRecordsWritten(3L)
     out.setRecordsWritten(4L)
-    out.setWriteMethod(DataWriteMethod.Hadoop)
     // assert new values exist
     assertValEquals(_.bytesWritten, BYTES_WRITTEN, 2L)
     assertValEquals(_.recordsWritten, RECORDS_WRITTEN, 4L)
-    // Note: this doesn't actually test anything, but there's only one 
DataWriteMethod
-    // so we can't set it to anything else
-    assertValEquals(_.writeMethod, WRITE_METHOD, DataWriteMethod.Hadoop)
   }
 
   test("merging multiple shuffle read metrics") {
     val tm = new TaskMetrics
-    assert(tm.shuffleReadMetrics.isEmpty)
-    val sr1 = tm.registerTempShuffleReadMetrics()
-    val sr2 = tm.registerTempShuffleReadMetrics()
-    val sr3 = tm.registerTempShuffleReadMetrics()
-    assert(tm.shuffleReadMetrics.isEmpty)
+    val sr1 = tm.createTempShuffleReadMetrics()
+    val sr2 = tm.createTempShuffleReadMetrics()
+    val sr3 = tm.createTempShuffleReadMetrics()
     sr1.setRecordsRead(10L)
     sr2.setRecordsRead(10L)
     sr1.setFetchWaitTime(1L)
     sr2.setFetchWaitTime(2L)
     sr3.setFetchWaitTime(3L)
     tm.mergeShuffleReadMetrics()
-    assert(tm.shuffleReadMetrics.isDefined)
-    val sr = tm.shuffleReadMetrics.get
-    assert(sr.remoteBlocksFetched === 0L)
-    assert(sr.recordsRead === 20L)
-    assert(sr.fetchWaitTime === 6L)
+    assert(tm.shuffleReadMetrics.remoteBlocksFetched === 0L)
+    assert(tm.shuffleReadMetrics.recordsRead === 20L)
+    assert(tm.shuffleReadMetrics.fetchWaitTime === 6L)
 
     // SPARK-5701: calling merge without any shuffle deps does nothing
     val tm2 = new TaskMetrics
     tm2.mergeShuffleReadMetrics()
-    assert(tm2.shuffleReadMetrics.isEmpty)
-  }
-
-  test("register multiple shuffle write metrics") {
-    val tm = new TaskMetrics
-    val sw1 = tm.registerShuffleWriteMetrics()
-    val sw2 = tm.registerShuffleWriteMetrics()
-    assert(sw1 === sw2)
-    assert(tm.shuffleWriteMetrics === Some(sw1))
-  }
-
-  test("register multiple input metrics") {
-    val tm = new TaskMetrics
-    val im1 = tm.registerInputMetrics(DataReadMethod.Memory)
-    val im2 = tm.registerInputMetrics(DataReadMethod.Memory)
-    // input metrics with a different read method than the one already 
registered are ignored
-    val im3 = tm.registerInputMetrics(DataReadMethod.Hadoop)
-    assert(im1 === im2)
-    assert(im1 !== im3)
-    assert(tm.inputMetrics === Some(im1))
-    im2.setBytesRead(50L)
-    im3.setBytesRead(100L)
-    assert(tm.inputMetrics.get.bytesRead === 50L)
-  }
-
-  test("register multiple output metrics") {
-    val tm = new TaskMetrics
-    val om1 = tm.registerOutputMetrics(DataWriteMethod.Hadoop)
-    val om2 = tm.registerOutputMetrics(DataWriteMethod.Hadoop)
-    assert(om1 === om2)
-    assert(tm.outputMetrics === Some(om1))
   }
 
   test("additional accumulables") {
@@ -424,10 +348,6 @@ class TaskMetricsSuite extends SparkFunSuite {
     assert(srAccum.isDefined)
     srAccum.get.asInstanceOf[Accumulator[Long]] += 10L
     val tm = new TaskMetrics(accums)
-    assert(tm.shuffleReadMetrics.isDefined)
-    assert(tm.shuffleWriteMetrics.isEmpty)
-    assert(tm.inputMetrics.isEmpty)
-    assert(tm.outputMetrics.isEmpty)
   }
 
   test("existing values in shuffle write accums") {
@@ -437,10 +357,6 @@ class TaskMetricsSuite extends SparkFunSuite {
     assert(swAccum.isDefined)
     swAccum.get.asInstanceOf[Accumulator[Long]] += 10L
     val tm = new TaskMetrics(accums)
-    assert(tm.shuffleReadMetrics.isEmpty)
-    assert(tm.shuffleWriteMetrics.isDefined)
-    assert(tm.inputMetrics.isEmpty)
-    assert(tm.outputMetrics.isEmpty)
   }
 
   test("existing values in input accums") {
@@ -450,10 +366,6 @@ class TaskMetricsSuite extends SparkFunSuite {
     assert(inAccum.isDefined)
     inAccum.get.asInstanceOf[Accumulator[Long]] += 10L
     val tm = new TaskMetrics(accums)
-    assert(tm.shuffleReadMetrics.isEmpty)
-    assert(tm.shuffleWriteMetrics.isEmpty)
-    assert(tm.inputMetrics.isDefined)
-    assert(tm.outputMetrics.isEmpty)
   }
 
   test("existing values in output accums") {
@@ -463,10 +375,6 @@ class TaskMetricsSuite extends SparkFunSuite {
     assert(outAccum.isDefined)
     outAccum.get.asInstanceOf[Accumulator[Long]] += 10L
     val tm4 = new TaskMetrics(accums)
-    assert(tm4.shuffleReadMetrics.isEmpty)
-    assert(tm4.shuffleWriteMetrics.isEmpty)
-    assert(tm4.inputMetrics.isEmpty)
-    assert(tm4.outputMetrics.isDefined)
   }
 
   test("from accumulator updates") {

http://git-wip-us.apache.org/repos/asf/spark/blob/8028a288/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala b/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala
index 056e546..f8054f5 100644
--- a/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala
@@ -25,16 +25,10 @@ import org.apache.commons.lang3.RandomUtils
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.{FileSystem, Path}
 import org.apache.hadoop.io.{LongWritable, Text}
-import org.apache.hadoop.mapred.{FileSplit => OldFileSplit, InputSplit => 
OldInputSplit,
-  JobConf, LineRecordReader => OldLineRecordReader, RecordReader => 
OldRecordReader,
-  Reporter, TextInputFormat => OldTextInputFormat}
-import org.apache.hadoop.mapred.lib.{CombineFileInputFormat => 
OldCombineFileInputFormat,
-  CombineFileRecordReader => OldCombineFileRecordReader, CombineFileSplit => 
OldCombineFileSplit}
-import org.apache.hadoop.mapreduce.{InputSplit => NewInputSplit, RecordReader 
=> NewRecordReader,
-  TaskAttemptContext}
-import org.apache.hadoop.mapreduce.lib.input.{CombineFileInputFormat => 
NewCombineFileInputFormat,
-  CombineFileRecordReader => NewCombineFileRecordReader, CombineFileSplit => 
NewCombineFileSplit,
-  FileSplit => NewFileSplit, TextInputFormat => NewTextInputFormat}
+import org.apache.hadoop.mapred.{FileSplit => OldFileSplit, InputSplit => 
OldInputSplit, JobConf, LineRecordReader => OldLineRecordReader, RecordReader 
=> OldRecordReader, Reporter, TextInputFormat => OldTextInputFormat}
+import org.apache.hadoop.mapred.lib.{CombineFileInputFormat => 
OldCombineFileInputFormat, CombineFileRecordReader => 
OldCombineFileRecordReader, CombineFileSplit => OldCombineFileSplit}
+import org.apache.hadoop.mapreduce.{InputSplit => NewInputSplit, RecordReader 
=> NewRecordReader, TaskAttemptContext}
+import org.apache.hadoop.mapreduce.lib.input.{CombineFileInputFormat => 
NewCombineFileInputFormat, CombineFileRecordReader => 
NewCombineFileRecordReader, CombineFileSplit => NewCombineFileSplit, FileSplit 
=> NewFileSplit, TextInputFormat => NewTextInputFormat}
 import org.apache.hadoop.mapreduce.lib.output.{TextOutputFormat => 
NewTextOutputFormat}
 import org.scalatest.BeforeAndAfter
 
@@ -103,40 +97,6 @@ class InputOutputMetricsSuite extends SparkFunSuite with SharedSparkContext
     assert(bytesRead2 == bytesRead)
   }
 
-  /**
-   * This checks the situation where we have interleaved reads from
-   * different sources. Currently, we only accumulate from the first
-   * read method we find in the task. This test uses cartesian to create
-   * the interleaved reads.
-   *
-   * Once https://issues.apache.org/jira/browse/SPARK-5225 is fixed
-   * this test should break.
-   */
-  test("input metrics with mixed read method") {
-    // prime the cache manager
-    val numPartitions = 2
-    val rdd = sc.parallelize(1 to 100, numPartitions).cache()
-    rdd.collect()
-
-    val rdd2 = sc.textFile(tmpFilePath, numPartitions)
-
-    val bytesRead = runAndReturnBytesRead {
-      rdd.count()
-    }
-    val bytesRead2 = runAndReturnBytesRead {
-      rdd2.count()
-    }
-
-    val cartRead = runAndReturnBytesRead {
-      rdd.cartesian(rdd2).count()
-    }
-
-    assert(cartRead != 0)
-    assert(bytesRead != 0)
-    // We read from the first rdd of the cartesian once per partition.
-    assert(cartRead == bytesRead * numPartitions)
-  }
-
   test("input metrics for new Hadoop API with coalesce") {
     val bytesRead = runAndReturnBytesRead {
       sc.newAPIHadoopFile(tmpFilePath, classOf[NewTextInputFormat], 
classOf[LongWritable],
@@ -209,10 +169,10 @@ class InputOutputMetricsSuite extends SparkFunSuite with SharedSparkContext
     sc.addSparkListener(new SparkListener() {
       override def onTaskEnd(taskEnd: SparkListenerTaskEnd) {
         val metrics = taskEnd.taskMetrics
-        metrics.inputMetrics.foreach(inputRead += _.recordsRead)
-        metrics.outputMetrics.foreach(outputWritten += _.recordsWritten)
-        metrics.shuffleReadMetrics.foreach(shuffleRead += _.recordsRead)
-        metrics.shuffleWriteMetrics.foreach(shuffleWritten += _.recordsWritten)
+        inputRead += metrics.inputMetrics.recordsRead
+        outputWritten += metrics.outputMetrics.recordsWritten
+        shuffleRead += metrics.shuffleReadMetrics.recordsRead
+        shuffleWritten += metrics.shuffleWriteMetrics.recordsWritten
       }
     })
 
@@ -272,19 +232,18 @@ class InputOutputMetricsSuite extends SparkFunSuite with SharedSparkContext
   }
 
   private def runAndReturnBytesRead(job: => Unit): Long = {
-    runAndReturnMetrics(job, _.taskMetrics.inputMetrics.map(_.bytesRead))
+    runAndReturnMetrics(job, _.taskMetrics.inputMetrics.bytesRead)
   }
 
   private def runAndReturnRecordsRead(job: => Unit): Long = {
-    runAndReturnMetrics(job, _.taskMetrics.inputMetrics.map(_.recordsRead))
+    runAndReturnMetrics(job, _.taskMetrics.inputMetrics.recordsRead)
   }
 
   private def runAndReturnRecordsWritten(job: => Unit): Long = {
-    runAndReturnMetrics(job, _.taskMetrics.outputMetrics.map(_.recordsWritten))
+    runAndReturnMetrics(job, _.taskMetrics.outputMetrics.recordsWritten)
   }
 
-  private def runAndReturnMetrics(job: => Unit,
-      collector: (SparkListenerTaskEnd) => Option[Long]): Long = {
+  private def runAndReturnMetrics(job: => Unit, collector: 
(SparkListenerTaskEnd) => Long): Long = {
     val taskMetrics = new ArrayBuffer[Long]()
 
     // Avoid receiving earlier taskEnd events
@@ -292,7 +251,7 @@ class InputOutputMetricsSuite extends SparkFunSuite with SharedSparkContext
 
     sc.addSparkListener(new SparkListener() {
       override def onTaskEnd(taskEnd: SparkListenerTaskEnd) {
-        collector(taskEnd).foreach(taskMetrics += _)
+        taskMetrics += collector(taskEnd)
       }
     })
 
@@ -337,7 +296,7 @@ class InputOutputMetricsSuite extends SparkFunSuite with SharedSparkContext
       val taskBytesWritten = new ArrayBuffer[Long]()
       sc.addSparkListener(new SparkListener() {
         override def onTaskEnd(taskEnd: SparkListenerTaskEnd) {
-          taskBytesWritten += 
taskEnd.taskMetrics.outputMetrics.get.bytesWritten
+          taskBytesWritten += taskEnd.taskMetrics.outputMetrics.bytesWritten
         }
       })
 


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to