Repository: spark
Updated Branches:
  refs/heads/master f48273c13 -> 3476390c6


[SPARK-20715] Store MapStatuses only in MapOutputTracker, not ShuffleMapStage

## What changes were proposed in this pull request?

This PR refactors `ShuffleMapStage` and `MapOutputTracker` in order to simplify 
the management of `MapStatuses`, reduce driver memory consumption, and remove a 
potential source of scheduler correctness bugs.

### Background

In Spark there are currently two places where MapStatuses are tracked:

- The `MapOutputTracker` maintains an `Array[MapStatus]` storing a single 
location for each map output. This mapping is used by the `DAGScheduler` for 
determining reduce-task locality preferences (when locality-aware reduce task 
scheduling is enabled) and is also used to serve map output locations to 
executors / tasks.
- Each `ShuffleMapStage` also contains a mapping of `Array[List[MapStatus]]` 
which holds the complete set of locations where each map output could be 
available. This mapping is used to determine which map tasks need to be run 
when constructing `TaskSets` for the stage.

This duplication adds complexity and creates the potential for certain types of 
correctness bugs. Bad things can happen if these two copies of the map output 
locations get out of sync. For instance, if the `MapOutputTracker` is missing 
locations for a map output but `ShuffleMapStage` believes that locations are 
available, then tasks will fail with `MetadataFetchFailedException` but 
`ShuffleMapStage` will not be updated to reflect the missing map outputs, 
leading to situations where the stage will be reattempted (because downstream 
stages experienced fetch failures) but no task sets will be launched (because 
`ShuffleMapStage` thinks all maps are available).
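To make this failure mode concrete, here is a deliberately simplified sketch of the two pre-patch copies. `TrackerCopy` and `StageCopy` are hypothetical stand-ins, not Spark classes, and locations are plain strings rather than `MapStatus` objects. The point is only that once the copies disagree, the stage-side view reports nothing to recompute even though the tracker-side view is missing an output.

```scala
// Hypothetical, heavily simplified model of the two pre-patch copies of map
// output state. These are NOT Spark's real classes; locations are plain strings.

// Tracker-side copy: at most one location per map output (None = missing).
final class TrackerCopy(numMaps: Int) {
  val locations: Array[Option[String]] = Array.fill(numMaps)(None)
}

// Stage-side copy: a list of candidate locations per map output.
final class StageCopy(numMaps: Int) {
  val outputLocs: Array[List[String]] = Array.fill(numMaps)(Nil)

  // The stage only launches tasks for partitions with no known location.
  def findMissingPartitions(): Seq[Int] =
    outputLocs.indices.filter(i => outputLocs(i).isEmpty)
}

val tracker = new TrackerCopy(numMaps = 2)
val stage = new StageCopy(numMaps = 2)

// Both copies record the outputs of map tasks 0 and 1.
for (mapId <- 0 until 2) {
  tracker.locations(mapId) = Some(s"exec-$mapId")
  stage.outputLocs(mapId) = s"exec-$mapId" :: stage.outputLocs(mapId)
}

// Suppose that, through some bug, only the tracker loses map 1's output.
tracker.locations(1) = None

// Reducers now fail to fetch map 1, but a resubmitted stage finds nothing to run.
val trackerMissing = tracker.locations.indices.filter(i => tracker.locations(i).isEmpty)
val stageMissing = stage.findMissingPartitions()
println(s"tracker missing: $trackerMissing, stage missing: $stageMissing")
```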

I observed this behavior in a real-world deployment. I'm still not quite sure 
how the state got out of sync in the first place, but we can completely avoid 
this class of bug if we eliminate the duplicate state.

### Why we only need to track a single location for each map output

I think that storing an `Array[List[MapStatus]]` in `ShuffleMapStage` is 
unnecessary.

First, note that this adds memory/object bloat to the driver: we need one extra 
`List` per task. If you have millions of tasks across all stages, then this can 
add up to a significant amount of resources.

Secondly, I believe that it's extremely uncommon that these lists will ever 
contain more than one entry. It's not impossible, but is very unlikely given 
the conditions which must occur for that to happen:

- In normal operation (no task failures) we'll only run each task once and thus 
will have at most one output.
- If speculation is enabled then it's possible that we'll have multiple 
attempts of a task. The TaskSetManager will [kill duplicate attempts of a 
task](https://github.com/apache/spark/blob/04901dd03a3f8062fd39ea38d585935ff71a9248/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala#L717)
 after a task finishes successfully, reducing the likelihood that both the 
original and speculated task will successfully register map outputs.
- There is a [comment in 
`TaskSetManager`](https://github.com/apache/spark/blob/04901dd03a3f8062fd39ea38d585935ff71a9248/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala#L113)
 which suggests that running tasks are not killed if a task set becomes a 
zombie. However:
  - If the task set becomes a zombie due to the job being cancelled then it 
doesn't matter whether we record map outputs.
  - If the task set became a zombie because of a stage failure (e.g. the map 
stage itself had a fetch failure from an upstream map stage) then I believe 
that the "failedEpoch" will be updated which may cause map outputs from 
still-running tasks to [be 
ignored](https://github.com/apache/spark/blob/04901dd03a3f8062fd39ea38d585935ff71a9248/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala#L1213).
 (I'm not 100% sure on this point, though).
- Even if you _do_ manage to record multiple map outputs for a stage, only a 
single map output is reported to / tracked by the MapOutputTracker. The only 
situation where the additional output locations could actually be read or used 
would be if a task experienced a `FetchFailure` exception. The most likely 
cause of a `FetchFailure` exception is a lost executor, which will most likely 
have caused the loss of several map tasks' outputs, so saving the potential 
re-execution of a single map task isn't a huge win if we're going to have to 
recompute several other lost map outputs from other tasks which ran on that 
lost executor. Also note that the re-population of MapOutputTracker state from 
state in the ShuffleMapStage only happens after the reduce stage has failed; the 
additional location doesn't help to prevent FetchFailures but, instead, can 
only reduce the amount of work when recomputing missing parent stages.

Given this, this patch chooses to do away with tracking multiple locations for 
map outputs and instead stores only a single location. This change removes the 
main distinction between the `ShuffleMapStage`'s and `MapOutputTracker`'s copies 
of this state, paving the way for storing it only in the `MapOutputTracker`.
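A minimal sketch of single-location bookkeeping, loosely modeled on the `addMapOutput` / `removeMapOutput` semantics of the `ShuffleStatus` class this patch introduces. `Status` and `SingleLocationTracker` are hypothetical names and locations are plain strings. A later registration simply overwrites the earlier one, and a removal only takes effect if it names the location that is currently recorded.

```scala
// Hypothetical stand-in for MapStatus: only the location matters here.
final case class Status(location: String)

// Sketch of single-location bookkeeping (not Spark code): one slot per map
// partition, null meaning "no output available".
final class SingleLocationTracker(numPartitions: Int) {
  private val statuses = new Array[Status](numPartitions)
  private var available = 0

  def add(mapId: Int, status: Status): Unit = synchronized {
    if (statuses(mapId) == null) available += 1
    statuses(mapId) = status // a newer registration replaces the old location
  }

  // Only clears the slot if the recorded location is the one being removed.
  def remove(mapId: Int, location: String): Unit = synchronized {
    val current = statuses(mapId)
    if (current != null && current.location == location) {
      statuses(mapId) = null
      available -= 1
    }
  }

  def numAvailable: Int = synchronized(available)

  def missing: Seq[Int] = synchronized {
    (0 until numPartitions).filter(i => statuses(i) == null)
  }
}

val t = new SingleLocationTracker(numPartitions = 2)
t.add(0, Status("exec-a"))
t.add(0, Status("exec-b")) // e.g. a later attempt overwrites exec-a's location
t.remove(0, "exec-a")      // stale removal: exec-a is no longer recorded, so no-op
t.add(1, Status("exec-c"))
println(s"available=${t.numAvailable}, missing=${t.missing}")
```

Because a removal has to match the recorded location, losing a location that was already superseded cannot wipe out the newer output.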

### Overview of other changes

- Significantly simplified the cache / lock management inside of the 
`MapOutputTrackerMaster`:
  - The old code had several parallel `HashMap`s which had to be guarded by 
maps of `Object`s which were used as locks. This code was somewhat complicated 
to follow.
  - The new code uses a new `ShuffleStatus` class to group together all of the 
state associated with a particular shuffle, including cached serialized map 
statuses, significantly simplifying the logic.
- Moved more code out of the shared `MapOutputTracker` abstract base class and 
into the `MapOutputTrackerMaster` and `MapOutputTrackerWorker` subclasses. This 
makes it easier to reason about which functionality needs to be supported only 
on the driver or executor.
- Removed a bunch of code from the `DAGScheduler` which was used to synchronize 
information from the `MapOutputTracker` to `ShuffleMapStage`.
- Added comments to clarify the role of `MapOutputTrackerMaster`'s `epoch` in 
invalidating executor-side shuffle map output caches.
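The `ShuffleStatus` idea of grouping all per-shuffle state behind a single lock, with a lazily populated serialization cache that is dropped on mutation, can be sketched roughly as follows. This is a hypothetical simplification and not the patch's code. `CachedShuffleState` is an illustrative name, the "serialization" is just comma-joined location strings rather than Spark's compressed format, there is no broadcast variable to destroy, and every mutation invalidates the cache.

```scala
import java.nio.charset.StandardCharsets

// Sketch only: one object owns the statuses AND their cached serialized form,
// so a single monitor guards both and they cannot drift apart.
final class CachedShuffleState(numPartitions: Int) {
  private val locations = new Array[String](numPartitions) // null = missing output
  private var cached: Array[Byte] = null                    // lazily filled cache
  private var count = 0                                     // serializations performed

  def addOutput(mapId: Int, location: String): Unit = synchronized {
    locations(mapId) = location
    cached = null // any change to the statuses invalidates the cached bytes
  }

  def removeOutput(mapId: Int): Unit = synchronized {
    if (locations(mapId) != null) {
      locations(mapId) = null
      cached = null
    }
  }

  // Concurrent callers block on this object's lock, so an empty cache is
  // filled by one thread and the others reuse the result.
  def serialized: Array[Byte] = synchronized {
    if (cached == null) {
      count += 1
      val text = locations.map(l => if (l == null) "" else l).mkString(",")
      cached = text.getBytes(StandardCharsets.UTF_8)
    }
    cached
  }

  def serializationCount: Int = synchronized(count)
}

val state = new CachedShuffleState(numPartitions = 2)
state.addOutput(0, "exec-a")
state.addOutput(1, "exec-b")
val first = state.serialized
val second = state.serialized // answered from the cache, no new serialization
state.removeOutput(1)
val third = state.serialized  // cache was invalidated, so this serializes again
println(s"serializations=${state.serializationCount}")
```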

I will comment on these changes via inline GitHub review comments.

/cc hvanhovell and rxin (whom I discussed this with offline), tgravescs (who 
recently worked on caching of serialized MapOutputStatuses), and kayousterhout 
and markhamstra (for scheduler changes).

## How was this patch tested?

Existing tests. I purposely avoided making interface / API changes which would 
require significant updates or modifications to test code.

Author: Josh Rosen <joshro...@databricks.com>

Closes #17955 from JoshRosen/map-output-tracker-rewrite.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/3476390c
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/3476390c
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/3476390c

Branch: refs/heads/master
Commit: 3476390c6e5d0fcfff340410f57e114039b5fbd4
Parents: f48273c
Author: Josh Rosen <joshro...@databricks.com>
Authored: Sun Jun 11 18:34:12 2017 -0700
Committer: Josh Rosen <joshro...@databricks.com>
Committed: Sun Jun 11 18:34:12 2017 -0700

----------------------------------------------------------------------
 .../org/apache/spark/MapOutputTracker.scala     | 636 +++++++++++--------
 .../org/apache/spark/executor/Executor.scala    |  10 +-
 .../apache/spark/scheduler/DAGScheduler.scala   |  51 +-
 .../spark/scheduler/ShuffleMapStage.scala       |  76 +--
 .../spark/scheduler/TaskSchedulerImpl.scala     |   2 +-
 .../apache/spark/MapOutputTrackerSuite.scala    |   6 +-
 .../scala/org/apache/spark/ShuffleSuite.scala   |   3 +-
 .../spark/scheduler/BlacklistTrackerSuite.scala |   3 +-
 8 files changed, 398 insertions(+), 389 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/3476390c/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
index 4ef6656..3e10b9e 100644
--- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -34,6 +34,156 @@ import org.apache.spark.shuffle.MetadataFetchFailedException
 import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockId}
 import org.apache.spark.util._
 
+/**
+ * Helper class used by the [[MapOutputTrackerMaster]] to perform bookkeeping for a single
+ * ShuffleMapStage.
+ *
+ * This class maintains a mapping from mapIds to `MapStatus`. It also maintains a cache of
+ * serialized map statuses in order to speed up tasks' requests for map output statuses.
+ *
+ * All public methods of this class are thread-safe.
+ */
+private class ShuffleStatus(numPartitions: Int) {
+
+  // All accesses to the following state must be guarded with `this.synchronized`.
+
+  /**
+   * MapStatus for each partition. The index of the array is the map partition id.
+   * Each value in the array is the MapStatus for a partition, or null if the partition
+   * is not available. Even though in theory a task may run multiple times (due to speculation,
+   * stage retries, etc.), in practice the likelihood of a map output being available at multiple
+   * locations is so small that we choose to ignore that case and store only a single location
+   * for each output.
+   */
+  private[this] val mapStatuses = new Array[MapStatus](numPartitions)
+
+  /**
+   * The cached result of serializing the map statuses array. This cache is lazily populated when
+   * [[serializedMapStatus]] is called. The cache is invalidated when map outputs are removed.
+   */
+  private[this] var cachedSerializedMapStatus: Array[Byte] = _
+
+  /**
+   * Broadcast variable holding serialized map output statuses array. When [[serializedMapStatus]]
+   * serializes the map statuses array it may detect that the result is too large to send in a
+   * single RPC, in which case it places the serialized array into a broadcast variable and then
+   * sends a serialized broadcast variable instead. This variable holds a reference to that
+   * broadcast variable in order to keep it from being garbage collected and to allow for it to be
+   * explicitly destroyed later on when the ShuffleMapStage is garbage-collected.
+   */
+  private[this] var cachedSerializedBroadcast: Broadcast[Array[Byte]] = _
+
+  /**
+   * Counter tracking the number of partitions that have output. This is a performance optimization
+   * to avoid having to count the number of non-null entries in the `mapStatuses` array and should
+   * be equivalent to`mapStatuses.count(_ ne null)`.
+   */
+  private[this] var _numAvailableOutputs: Int = 0
+
+  /**
+   * Register a map output. If there is already a registered location for the map output then it
+   * will be replaced by the new location.
+   */
+  def addMapOutput(mapId: Int, status: MapStatus): Unit = synchronized {
+    if (mapStatuses(mapId) == null) {
+      _numAvailableOutputs += 1
+      invalidateSerializedMapOutputStatusCache()
+    }
+    mapStatuses(mapId) = status
+  }
+
+  /**
+   * Remove the map output which was served by the specified block manager.
+   * This is a no-op if there is no registered map output or if the registered output is from a
+   * different block manager.
+   */
+  def removeMapOutput(mapId: Int, bmAddress: BlockManagerId): Unit = synchronized {
+    if (mapStatuses(mapId) != null && mapStatuses(mapId).location == bmAddress) {
+      _numAvailableOutputs -= 1
+      mapStatuses(mapId) = null
+      invalidateSerializedMapOutputStatusCache()
+    }
+  }
+
+  /**
+   * Removes all map outputs associated with the specified executor. Note that this will also
+   * remove outputs which are served by an external shuffle server (if one exists), as they are
+   * still registered with that execId.
+   */
+  def removeOutputsOnExecutor(execId: String): Unit = synchronized {
+    for (mapId <- 0 until mapStatuses.length) {
+      if (mapStatuses(mapId) != null && mapStatuses(mapId).location.executorId == execId) {
+        _numAvailableOutputs -= 1
+        mapStatuses(mapId) = null
+        invalidateSerializedMapOutputStatusCache()
+      }
+    }
+  }
+
+  /**
+   * Number of partitions that have shuffle outputs.
+   */
+  def numAvailableOutputs: Int = synchronized {
+    _numAvailableOutputs
+  }
+
+  /**
+   * Returns the sequence of partition ids that are missing (i.e. needs to be computed).
+   */
+  def findMissingPartitions(): Seq[Int] = synchronized {
+    val missing = (0 until numPartitions).filter(id => mapStatuses(id) == null)
+    assert(missing.size == numPartitions - _numAvailableOutputs,
+      s"${missing.size} missing, expected ${numPartitions - _numAvailableOutputs}")
+    missing
+  }
+
+  /**
+   * Serializes the mapStatuses array into an efficient compressed format. See the comments on
+   * `MapOutputTracker.serializeMapStatuses()` for more details on the serialization format.
+   *
+   * This method is designed to be called multiple times and implements caching in order to speed
+   * up subsequent requests. If the cache is empty and multiple threads concurrently attempt to
+   * serialize the map statuses then serialization will only be performed in a single thread and all
+   * other threads will block until the cache is populated.
+   */
+  def serializedMapStatus(
+      broadcastManager: BroadcastManager,
+      isLocal: Boolean,
+      minBroadcastSize: Int): Array[Byte] = synchronized {
+    if (cachedSerializedMapStatus eq null) {
+      val serResult = MapOutputTracker.serializeMapStatuses(
+          mapStatuses, broadcastManager, isLocal, minBroadcastSize)
+      cachedSerializedMapStatus = serResult._1
+      cachedSerializedBroadcast = serResult._2
+    }
+    cachedSerializedMapStatus
+  }
+
+  // Used in testing.
+  def hasCachedSerializedBroadcast: Boolean = synchronized {
+    cachedSerializedBroadcast != null
+  }
+
+  /**
+   * Helper function which provides thread-safe access to the mapStatuses array.
+   * The function should NOT mutate the array.
+   */
+  def withMapStatuses[T](f: Array[MapStatus] => T): T = synchronized {
+    f(mapStatuses)
+  }
+
+  /**
+   * Clears the cached serialized map output statuses.
+   */
+  def invalidateSerializedMapOutputStatusCache(): Unit = synchronized {
+    if (cachedSerializedBroadcast != null) {
+      cachedSerializedBroadcast.destroy()
+      cachedSerializedBroadcast = null
+    }
+    cachedSerializedMapStatus = null
+  }
+}
+
 private[spark] sealed trait MapOutputTrackerMessage
 private[spark] case class GetMapOutputStatuses(shuffleId: Int)
   extends MapOutputTrackerMessage
@@ -62,37 +212,26 @@ private[spark] class MapOutputTrackerMasterEndpoint(
 }
 
 /**
- * Class that keeps track of the location of the map output of
- * a stage. This is abstract because different versions of MapOutputTracker
- * (driver and executor) use different HashMap to store its metadata.
- */
+ * Class that keeps track of the location of the map output of a stage. This is abstract because the
+ * driver and executor have different versions of the MapOutputTracker. In principle the driver-
+ * and executor-side classes don't need to share a common base class; the current shared base class
+ * is maintained primarily for backwards-compatibility in order to avoid having to update existing
+ * test code.
+*/
 private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging {
-
   /** Set to the MapOutputTrackerMasterEndpoint living on the driver. */
   var trackerEndpoint: RpcEndpointRef = _
 
   /**
-   * This HashMap has different behavior for the driver and the executors.
-   *
-   * On the driver, it serves as the source of map outputs recorded from ShuffleMapTasks.
-   * On the executors, it simply serves as a cache, in which a miss triggers a fetch from the
-   * driver's corresponding HashMap.
-   *
-   * Note: because mapStatuses is accessed concurrently, subclasses should make sure it's a
-   * thread-safe map.
-   */
-  protected val mapStatuses: Map[Int, Array[MapStatus]]
-
-  /**
-   * Incremented every time a fetch fails so that client nodes know to clear
-   * their cache of map output locations if this happens.
+   * The driver-side counter is incremented every time that a map output is lost. This value is sent
+   * to executors as part of tasks, where executors compare the new epoch number to the highest
+   * epoch number that they received in the past. If the new epoch number is higher then executors
+   * will clear their local caches of map output statuses and will re-fetch (possibly updated)
+   * statuses from the driver.
    */
   protected var epoch: Long = 0
   protected val epochLock = new AnyRef
 
-  /** Remembers which map output locations are currently being fetched on an executor. */
-  private val fetching = new HashSet[Int]
-
   /**
    * Send a message to the trackerEndpoint and get its result within a default timeout, or
    * throw a SparkException if this fails.
@@ -116,14 +255,7 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging
     }
   }
 
-  /**
-   * Called from executors to get the server URIs and output sizes for each shuffle block that
-   * needs to be read from a given reduce task.
-   *
-   * @return A sequence of 2-item tuples, where the first item in the tuple is a BlockManagerId,
-   *         and the second item is a sequence of (shuffle block id, shuffle block size) tuples
-   *         describing the shuffle blocks that are stored at that block manager.
-   */
+  // For testing
   def getMapSizesByExecutorId(shuffleId: Int, reduceId: Int)
       : Seq[(BlockManagerId, Seq[(BlockId, Long)])] = {
     getMapSizesByExecutorId(shuffleId, reduceId, reduceId + 1)
@@ -139,135 +271,31 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging
    *         describing the shuffle blocks that are stored at that block manager.
    */
   def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int)
-      : Seq[(BlockManagerId, Seq[(BlockId, Long)])] = {
-    logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition")
-    val statuses = getStatuses(shuffleId)
-    // Synchronize on the returned array because, on the driver, it gets mutated in place
-    statuses.synchronized {
-      return MapOutputTracker.convertMapStatuses(shuffleId, startPartition, endPartition, statuses)
-    }
-  }
+      : Seq[(BlockManagerId, Seq[(BlockId, Long)])]
 
   /**
-   * Return statistics about all of the outputs for a given shuffle.
+   * Deletes map output status information for the specified shuffle stage.
    */
-  def getStatistics(dep: ShuffleDependency[_, _, _]): MapOutputStatistics = {
-    val statuses = getStatuses(dep.shuffleId)
-    // Synchronize on the returned array because, on the driver, it gets mutated in place
-    statuses.synchronized {
-      val totalSizes = new Array[Long](dep.partitioner.numPartitions)
-      for (s <- statuses) {
-        for (i <- 0 until totalSizes.length) {
-          totalSizes(i) += s.getSizeForBlock(i)
-        }
-      }
-      new MapOutputStatistics(dep.shuffleId, totalSizes)
-    }
-  }
-
-  /**
-   * Get or fetch the array of MapStatuses for a given shuffle ID. NOTE: clients MUST synchronize
-   * on this array when reading it, because on the driver, we may be changing it in place.
-   *
-   * (It would be nice to remove this restriction in the future.)
-   */
-  private def getStatuses(shuffleId: Int): Array[MapStatus] = {
-    val statuses = mapStatuses.get(shuffleId).orNull
-    if (statuses == null) {
-      logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them")
-      val startTime = System.currentTimeMillis
-      var fetchedStatuses: Array[MapStatus] = null
-      fetching.synchronized {
-        // Someone else is fetching it; wait for them to be done
-        while (fetching.contains(shuffleId)) {
-          try {
-            fetching.wait()
-          } catch {
-            case e: InterruptedException =>
-          }
-        }
-
-        // Either while we waited the fetch happened successfully, or
-        // someone fetched it in between the get and the fetching.synchronized.
-        fetchedStatuses = mapStatuses.get(shuffleId).orNull
-        if (fetchedStatuses == null) {
-          // We have to do the fetch, get others to wait for us.
-          fetching += shuffleId
-        }
-      }
+  def unregisterShuffle(shuffleId: Int): Unit
 
-      if (fetchedStatuses == null) {
-        // We won the race to fetch the statuses; do so
-        logInfo("Doing the fetch; tracker endpoint = " + trackerEndpoint)
-        // This try-finally prevents hangs due to timeouts:
-        try {
-          val fetchedBytes = askTracker[Array[Byte]](GetMapOutputStatuses(shuffleId))
-          fetchedStatuses = MapOutputTracker.deserializeMapStatuses(fetchedBytes)
-          logInfo("Got the output locations")
-          mapStatuses.put(shuffleId, fetchedStatuses)
-        } finally {
-          fetching.synchronized {
-            fetching -= shuffleId
-            fetching.notifyAll()
-          }
-        }
-      }
-      logDebug(s"Fetching map output statuses for shuffle $shuffleId took " +
-        s"${System.currentTimeMillis - startTime} ms")
-
-      if (fetchedStatuses != null) {
-        return fetchedStatuses
-      } else {
-        logError("Missing all output locations for shuffle " + shuffleId)
-        throw new MetadataFetchFailedException(
-          shuffleId, -1, "Missing all output locations for shuffle " + shuffleId)
-      }
-    } else {
-      return statuses
-    }
-  }
-
-  /** Called to get current epoch number. */
-  def getEpoch: Long = {
-    epochLock.synchronized {
-      return epoch
-    }
-  }
-
-  /**
-   * Called from executors to update the epoch number, potentially clearing old outputs
-   * because of a fetch failure. Each executor task calls this with the latest epoch
-   * number on the driver at the time it was created.
-   */
-  def updateEpoch(newEpoch: Long) {
-    epochLock.synchronized {
-      if (newEpoch > epoch) {
-        logInfo("Updating epoch to " + newEpoch + " and clearing cache")
-        epoch = newEpoch
-        mapStatuses.clear()
-      }
-    }
-  }
-
-  /** Unregister shuffle data. */
-  def unregisterShuffle(shuffleId: Int) {
-    mapStatuses.remove(shuffleId)
-  }
-
-  /** Stop the tracker. */
-  def stop() { }
+  def stop() {}
 }
 
 /**
- * MapOutputTracker for the driver.
+ * Driver-side class that keeps track of the location of the map output of a stage.
+ *
+ * The DAGScheduler uses this class to (de)register map output statuses and to look up statistics
+ * for performing locality-aware reduce task scheduling.
+ *
+ * ShuffleMapStage uses this class for tracking available / missing outputs in order to determine
+ * which tasks need to be run.
  */
-private[spark] class MapOutputTrackerMaster(conf: SparkConf,
-    broadcastManager: BroadcastManager, isLocal: Boolean)
+private[spark] class MapOutputTrackerMaster(
+    conf: SparkConf,
+    broadcastManager: BroadcastManager,
+    isLocal: Boolean)
   extends MapOutputTracker(conf) {
 
-  /** Cache a serialized version of the output statuses for each shuffle to send them out faster */
-  private var cacheEpoch = epoch
-
   // The size at which we use Broadcast to send the map output statuses to the executors
   private val minSizeForBroadcast =
     conf.getSizeAsBytes("spark.shuffle.mapOutput.minSizeForBroadcast", "512k").toInt
@@ -287,22 +315,12 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf,
   // can be read locally, but may lead to more delay in scheduling if those locations are busy.
   private val REDUCER_PREF_LOCS_FRACTION = 0.2
 
-  // HashMaps for storing mapStatuses and cached serialized statuses in the driver.
+  // HashMap for storing shuffleStatuses in the driver.
   // Statuses are dropped only by explicit de-registering.
-  protected val mapStatuses = new ConcurrentHashMap[Int, Array[MapStatus]]().asScala
-  private val cachedSerializedStatuses = new ConcurrentHashMap[Int, Array[Byte]]().asScala
+  private val shuffleStatuses = new ConcurrentHashMap[Int, ShuffleStatus]().asScala
 
   private val maxRpcMessageSize = RpcUtils.maxMessageSizeBytes(conf)
 
-  // Kept in sync with cachedSerializedStatuses explicitly
-  // This is required so that the Broadcast variable remains in scope until we remove
-  // the shuffleId explicitly or implicitly.
-  private val cachedSerializedBroadcast = new HashMap[Int, Broadcast[Array[Byte]]]()
-
-  // This is to prevent multiple serializations of the same shuffle - which happens when
-  // there is a request storm when shuffle start.
-  private val shuffleIdLocks = new ConcurrentHashMap[Int, AnyRef]()
-
   // requests for map output statuses
   private val mapOutputRequests = new LinkedBlockingQueue[GetMapOutputMessage]
 
@@ -348,8 +366,9 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf,
             val hostPort = context.senderAddress.hostPort
             logDebug("Handling request to send map output locations for shuffle " + shuffleId +
               " to " + hostPort)
-            val mapOutputStatuses = getSerializedMapOutputStatuses(shuffleId)
-            context.reply(mapOutputStatuses)
+            val shuffleStatus = shuffleStatuses.get(shuffleId).head
+            context.reply(
+              shuffleStatus.serializedMapStatus(broadcastManager, isLocal, minSizeForBroadcast))
           } catch {
             case NonFatal(e) => logError(e.getMessage, e)
           }
@@ -363,59 +382,77 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf,
   /** A poison endpoint that indicates MessageLoop should exit its message loop. */
   private val PoisonPill = new GetMapOutputMessage(-99, null)
 
-  // Exposed for testing
-  private[spark] def getNumCachedSerializedBroadcast = cachedSerializedBroadcast.size
+  // Used only in unit tests.
+  private[spark] def getNumCachedSerializedBroadcast: Int = {
+    shuffleStatuses.valuesIterator.count(_.hasCachedSerializedBroadcast)
+  }
 
   def registerShuffle(shuffleId: Int, numMaps: Int) {
-    if (mapStatuses.put(shuffleId, new Array[MapStatus](numMaps)).isDefined) {
+    if (shuffleStatuses.put(shuffleId, new ShuffleStatus(numMaps)).isDefined) {
       throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice")
     }
-    // add in advance
-    shuffleIdLocks.putIfAbsent(shuffleId, new Object())
   }
 
   def registerMapOutput(shuffleId: Int, mapId: Int, status: MapStatus) {
-    val array = mapStatuses(shuffleId)
-    array.synchronized {
-      array(mapId) = status
-    }
-  }
-
-  /** Register multiple map output information for the given shuffle */
-  def registerMapOutputs(shuffleId: Int, statuses: Array[MapStatus], changeEpoch: Boolean = false) {
-    mapStatuses.put(shuffleId, statuses.clone())
-    if (changeEpoch) {
-      incrementEpoch()
-    }
+    shuffleStatuses(shuffleId).addMapOutput(mapId, status)
   }
 
   /** Unregister map output information of the given shuffle, mapper and block manager */
   def unregisterMapOutput(shuffleId: Int, mapId: Int, bmAddress: BlockManagerId) {
-    val arrayOpt = mapStatuses.get(shuffleId)
-    if (arrayOpt.isDefined && arrayOpt.get != null) {
-      val array = arrayOpt.get
-      array.synchronized {
-        if (array(mapId) != null && array(mapId).location == bmAddress) {
-          array(mapId) = null
-        }
-      }
-      incrementEpoch()
-    } else {
-      throw new SparkException("unregisterMapOutput called for nonexistent shuffle ID")
+    shuffleStatuses.get(shuffleId) match {
+      case Some(shuffleStatus) =>
+        shuffleStatus.removeMapOutput(mapId, bmAddress)
+        incrementEpoch()
+      case None =>
+        throw new SparkException("unregisterMapOutput called for nonexistent shuffle ID")
     }
   }
 
   /** Unregister shuffle data */
-  override def unregisterShuffle(shuffleId: Int) {
-    mapStatuses.remove(shuffleId)
-    cachedSerializedStatuses.remove(shuffleId)
-    cachedSerializedBroadcast.remove(shuffleId).foreach(v => removeBroadcast(v))
-    shuffleIdLocks.remove(shuffleId)
+  def unregisterShuffle(shuffleId: Int) {
+    shuffleStatuses.remove(shuffleId).foreach { shuffleStatus =>
+      shuffleStatus.invalidateSerializedMapOutputStatusCache()
+    }
+  }
+
+  /**
+   * Removes all shuffle outputs associated with this executor. Note that this will also remove
+   * outputs which are served by an external shuffle server (if one exists), as they are still
+   * registered with this execId.
+   */
+  def removeOutputsOnExecutor(execId: String): Unit = {
+    shuffleStatuses.valuesIterator.foreach { _.removeOutputsOnExecutor(execId) }
+    incrementEpoch()
   }
 
   /** Check if the given shuffle is being tracked */
-  def containsShuffle(shuffleId: Int): Boolean = {
-    cachedSerializedStatuses.contains(shuffleId) || mapStatuses.contains(shuffleId)
+  def containsShuffle(shuffleId: Int): Boolean = shuffleStatuses.contains(shuffleId)
+
+  def getNumAvailableOutputs(shuffleId: Int): Int = {
+    shuffleStatuses.get(shuffleId).map(_.numAvailableOutputs).getOrElse(0)
+  }
+
+  /**
+   * Returns the sequence of partition ids that are missing (i.e. needs to be computed), or None
+   * if the MapOutputTrackerMaster doesn't know about this shuffle.
+   */
+  def findMissingPartitions(shuffleId: Int): Option[Seq[Int]] = {
+    shuffleStatuses.get(shuffleId).map(_.findMissingPartitions())
+  }
+
+  /**
+   * Return statistics about all of the outputs for a given shuffle.
+   */
+  def getStatistics(dep: ShuffleDependency[_, _, _]): MapOutputStatistics = {
+    shuffleStatuses(dep.shuffleId).withMapStatuses { statuses =>
+      val totalSizes = new Array[Long](dep.partitioner.numPartitions)
+      for (s <- statuses) {
+        for (i <- 0 until totalSizes.length) {
+          totalSizes(i) += s.getSizeForBlock(i)
+        }
+      }
+      new MapOutputStatistics(dep.shuffleId, totalSizes)
+    }
   }
 
   /**
@@ -459,9 +496,9 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf,
       fractionThreshold: Double)
     : Option[Array[BlockManagerId]] = {
 
-    val statuses = mapStatuses.get(shuffleId).orNull
-    if (statuses != null) {
-      statuses.synchronized {
+    val shuffleStatus = shuffleStatuses.get(shuffleId).orNull
+    if (shuffleStatus != null) {
+      shuffleStatus.withMapStatuses { statuses =>
         if (statuses.nonEmpty) {
           // HashMap to add up sizes of all blocks at the same location
           val locs = new HashMap[BlockManagerId, Long]
@@ -502,77 +539,24 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf,
     }
   }
 
-  private def removeBroadcast(bcast: Broadcast[_]): Unit = {
-    if (null != bcast) {
-      broadcastManager.unbroadcast(bcast.id,
-        removeFromDriver = true, blocking = false)
+  /** Called to get current epoch number. */
+  def getEpoch: Long = {
+    epochLock.synchronized {
+      return epoch
     }
   }
 
-  private def clearCachedBroadcast(): Unit = {
-    for (cached <- cachedSerializedBroadcast) removeBroadcast(cached._2)
-    cachedSerializedBroadcast.clear()
-  }
-
-  def getSerializedMapOutputStatuses(shuffleId: Int): Array[Byte] = {
-    var statuses: Array[MapStatus] = null
-    var retBytes: Array[Byte] = null
-    var epochGotten: Long = -1
-
-    // Check to see if we have a cached version, returns true if it does
-    // and has side effect of setting retBytes.  If not returns false
-    // with side effect of setting statuses
-    def checkCachedStatuses(): Boolean = {
-      epochLock.synchronized {
-        if (epoch > cacheEpoch) {
-          cachedSerializedStatuses.clear()
-          clearCachedBroadcast()
-          cacheEpoch = epoch
-        }
-        cachedSerializedStatuses.get(shuffleId) match {
-          case Some(bytes) =>
-            retBytes = bytes
-            true
-          case None =>
-            logDebug("cached status not found for : " + shuffleId)
-            statuses = mapStatuses.getOrElse(shuffleId, Array.empty[MapStatus])
-            epochGotten = epoch
-            false
-        }
-      }
-    }
-
-    if (checkCachedStatuses()) return retBytes
-    var shuffleIdLock = shuffleIdLocks.get(shuffleId)
-    if (null == shuffleIdLock) {
-      val newLock = new Object()
-      // in general, this condition should be false - but good to be paranoid
-      val prevLock = shuffleIdLocks.putIfAbsent(shuffleId, newLock)
-      shuffleIdLock = if (null != prevLock) prevLock else newLock
-    }
-    // synchronize so we only serialize/broadcast it once since multiple threads call
-    // in parallel
-    shuffleIdLock.synchronized {
-      // double check to make sure someone else didn't serialize and cache the same
-      // mapstatus while we were waiting on the synchronize
-      if (checkCachedStatuses()) return retBytes
-
-      // If we got here, we failed to find the serialized locations in the cache, so we pulled
-      // out a snapshot of the locations as "statuses"; let's serialize and return that
-      val (bytes, bcast) = MapOutputTracker.serializeMapStatuses(statuses, broadcastManager,
-        isLocal, minSizeForBroadcast)
-      logInfo("Size of output statuses for shuffle %d is %d bytes".format(shuffleId, bytes.length))
-      // Add them into the table only if the epoch hasn't changed while we were working
-      epochLock.synchronized {
-        if (epoch == epochGotten) {
-          cachedSerializedStatuses(shuffleId) = bytes
-          if (null != bcast) cachedSerializedBroadcast(shuffleId) = bcast
-        } else {
-          logInfo("Epoch changed, not caching!")
-          removeBroadcast(bcast)
+  // This method is only called in local-mode.
+  def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int)
+      : Seq[(BlockManagerId, Seq[(BlockId, Long)])] = {
+    logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition")
+    shuffleStatuses.get(shuffleId) match {
+      case Some (shuffleStatus) =>
+        shuffleStatus.withMapStatuses { statuses =>
+          MapOutputTracker.convertMapStatuses(shuffleId, startPartition, endPartition, statuses)
         }
-      }
-      bytes
+      case None =>
+        Seq.empty
     }
   }
 
@@ -580,21 +564,121 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf,
     mapOutputRequests.offer(PoisonPill)
     threadpool.shutdown()
     sendTracker(StopMapOutputTracker)
-    mapStatuses.clear()
     trackerEndpoint = null
-    cachedSerializedStatuses.clear()
-    clearCachedBroadcast()
-    shuffleIdLocks.clear()
+    shuffleStatuses.clear()
   }
 }
 
 /**
- * MapOutputTracker for the executors, which fetches map output information from the driver's
- * MapOutputTrackerMaster.
+ * Executor-side client for fetching map output info from the driver's MapOutputTrackerMaster.
+ * Note that this is not used in local-mode; instead, local-mode Executors access the
+ * MapOutputTrackerMaster directly (which is possible because the master and worker share a common
+ * superclass).
  */
 private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTracker(conf) {
-  protected val mapStatuses: Map[Int, Array[MapStatus]] =
+
+  val mapStatuses: Map[Int, Array[MapStatus]] =
     new ConcurrentHashMap[Int, Array[MapStatus]]().asScala
+
+  /** Remembers which map output locations are currently being fetched on an executor. */
+  private val fetching = new HashSet[Int]
+
+  override def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int)
+      : Seq[(BlockManagerId, Seq[(BlockId, Long)])] = {
+    logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition")
+    val statuses = getStatuses(shuffleId)
+    try {
+      MapOutputTracker.convertMapStatuses(shuffleId, startPartition, endPartition, statuses)
+    } catch {
+      case e: MetadataFetchFailedException =>
+        // We experienced a fetch failure so our mapStatuses cache is outdated; clear it:
+        mapStatuses.clear()
+        throw e
+    }
+  }
+
+  /**
+   * Get or fetch the array of MapStatuses for a given shuffle ID. NOTE: clients MUST synchronize
+   * on this array when reading it, because on the driver, we may be changing it in place.
+   *
+   * (It would be nice to remove this restriction in the future.)
+   */
+  private def getStatuses(shuffleId: Int): Array[MapStatus] = {
+    val statuses = mapStatuses.get(shuffleId).orNull
+    if (statuses == null) {
+      logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them")
+      val startTime = System.currentTimeMillis
+      var fetchedStatuses: Array[MapStatus] = null
+      fetching.synchronized {
+        // Someone else is fetching it; wait for them to be done
+        while (fetching.contains(shuffleId)) {
+          try {
+            fetching.wait()
+          } catch {
+            case e: InterruptedException =>
+          }
+        }
+
+        // Either while we waited the fetch happened successfully, or
+        // someone fetched it in between the get and the fetching.synchronized.
+        fetchedStatuses = mapStatuses.get(shuffleId).orNull
+        if (fetchedStatuses == null) {
+          // We have to do the fetch, get others to wait for us.
+          fetching += shuffleId
+        }
+      }
+
+      if (fetchedStatuses == null) {
+        // We won the race to fetch the statuses; do so
+        logInfo("Doing the fetch; tracker endpoint = " + trackerEndpoint)
+        // This try-finally prevents hangs due to timeouts:
+        try {
+          val fetchedBytes = askTracker[Array[Byte]](GetMapOutputStatuses(shuffleId))
+          fetchedStatuses = MapOutputTracker.deserializeMapStatuses(fetchedBytes)
+          logInfo("Got the output locations")
+          mapStatuses.put(shuffleId, fetchedStatuses)
+        } finally {
+          fetching.synchronized {
+            fetching -= shuffleId
+            fetching.notifyAll()
+          }
+        }
+      }
+      logDebug(s"Fetching map output statuses for shuffle $shuffleId took " +
+        s"${System.currentTimeMillis - startTime} ms")
+
+      if (fetchedStatuses != null) {
+        fetchedStatuses
+      } else {
+        logError("Missing all output locations for shuffle " + shuffleId)
+        throw new MetadataFetchFailedException(
+          shuffleId, -1, "Missing all output locations for shuffle " + shuffleId)
+      }
+    } else {
+      statuses
+    }
+  }
+
+
+  /** Unregister shuffle data. */
+  def unregisterShuffle(shuffleId: Int): Unit = {
+    mapStatuses.remove(shuffleId)
+  }
+
+  /**
+   * Called from executors to update the epoch number, potentially clearing old outputs
+   * because of a fetch failure. Each executor task calls this with the latest epoch
+   * number on the driver at the time it was created.
+   */
+  def updateEpoch(newEpoch: Long): Unit = {
+    epochLock.synchronized {
+      if (newEpoch > epoch) {
+        logInfo("Updating epoch to " + newEpoch + " and clearing cache")
+        epoch = newEpoch
+        mapStatuses.clear()
+      }
+    }
+  }
 }
 
 private[spark] object MapOutputTracker extends Logging {
@@ -683,7 +767,7 @@ private[spark] object MapOutputTracker extends Logging {
    *         and the second item is a sequence of (shuffle block ID, shuffle block size) tuples
    *         describing the shuffle blocks that are stored at that block manager.
    */
-  private def convertMapStatuses(
+  def convertMapStatuses(
       shuffleId: Int,
       startPartition: Int,
       endPartition: Int,

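The tracker-side queries in this file (`findMissingPartitions`, `getStatistics`) now answer from a single per-shuffle array of statuses instead of a second copy held by the stage. The following is a minimal sketch of that bookkeeping under simplifying assumptions, not Spark's `ShuffleStatus`: `SimpleMapStatus` and `SimpleShuffleStatus` are hypothetical stand-ins, and a missing map output is modelled as a `null` slot.

```scala
// Hypothetical stand-in for Spark's MapStatus: the executor holding one map
// output and the size of the block it wrote for each reduce partition.
final case class SimpleMapStatus(executorId: String, blockSizes: Array[Long])

// Hypothetical per-shuffle state: one slot per map partition, null if missing.
final class SimpleShuffleStatus(numMaps: Int) {
  private val statuses = new Array[SimpleMapStatus](numMaps)

  def addMapOutput(mapId: Int, status: SimpleMapStatus): Unit = synchronized {
    statuses(mapId) = status
  }

  def removeMapOutput(mapId: Int): Unit = synchronized {
    statuses(mapId) = null
  }

  def numAvailableOutputs: Int = synchronized {
    statuses.count(_ != null)
  }

  // Map partitions that still need to be (re)computed.
  def findMissingPartitions(): Seq[Int] = synchronized {
    (0 until numMaps).filter(i => statuses(i) == null)
  }

  // Column-wise sum over map outputs: total bytes bound for each reduce
  // partition. Missing (null) slots are skipped.
  def totalSizes(numReducers: Int): Array[Long] = synchronized {
    val totals = new Array[Long](numReducers)
    for (s <- statuses if s != null; i <- 0 until numReducers) {
      totals(i) += s.blockSizes(i)
    }
    totals
  }
}
```

With outputs registered for map partitions 0 and 2 of a three-map shuffle, `findMissingPartitions()` reports only partition 1, and `totalSizes` adds up each reduce partition's block sizes across the available outputs, the same column-wise sum that `getStatistics` performs.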
http://git-wip-us.apache.org/repos/asf/spark/blob/3476390c/core/src/main/scala/org/apache/spark/executor/Executor.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index 5b39668..19e7eb0 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -322,8 +322,14 @@ private[spark] class Executor(
           throw new TaskKilledException(killReason.get)
         }
 
-        logDebug("Task " + taskId + "'s epoch is " + task.epoch)
-        env.mapOutputTracker.updateEpoch(task.epoch)
+        // The purpose of updating the epoch here is to invalidate executor map output status cache
+        // in case FetchFailures have occurred. In local mode `env.mapOutputTracker` will be
+        // MapOutputTrackerMaster and its cache invalidation is not based on epoch numbers so
+        // we don't need to make any special calls here.
+        if (!isLocal) {
+          logDebug("Task " + taskId + "'s epoch is " + task.epoch)
+          env.mapOutputTracker.asInstanceOf[MapOutputTrackerWorker].updateEpoch(task.epoch)
+        }
 
         // Run the actual task and measure its runtime.
         taskStart = System.currentTimeMillis()

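The comment in `Executor` explains that the epoch update exists to invalidate the executor's cached map output statuses. Below is a minimal sketch of that worker-side pattern, modelled loosely on `getStatuses` and `updateEpoch` in `MapOutputTrackerWorker`: statuses are cached per shuffle ID, a newer epoch clears the whole cache, and only one thread fetches a given shuffle while the others wait. `EpochStatusCache` and its `fetch` function (standing in for the `askTracker` RPC to the driver) are hypothetical.

```scala
import java.util.concurrent.ConcurrentHashMap
import scala.collection.mutable

// Hypothetical executor-side cache. `fetch` stands in for the RPC that asks
// the driver for one shuffle's statuses.
final class EpochStatusCache[V <: AnyRef](fetch: Int => V) {
  private val cache = new ConcurrentHashMap[Int, V]()
  private val fetching = mutable.HashSet.empty[Int] // guarded by `fetching`
  private val epochLock = new Object
  private var epoch = 0L                             // guarded by `epochLock`

  // Tasks report the driver epoch they were created under; a newer epoch means
  // some map outputs changed, so every cached entry may be stale.
  def updateEpoch(newEpoch: Long): Unit = epochLock.synchronized {
    if (newEpoch > epoch) {
      epoch = newEpoch
      cache.clear()
    }
  }

  def get(shuffleId: Int): V = Option(cache.get(shuffleId)).getOrElse {
    // Wait while another thread fetches this shuffle, then re-check the cache.
    val alreadyFetched = fetching.synchronized {
      while (fetching.contains(shuffleId)) fetching.wait()
      val v = cache.get(shuffleId)
      if (v == null) fetching += shuffleId // we won the race; others wait for us
      Option(v)
    }
    alreadyFetched.getOrElse {
      try {
        val v = fetch(shuffleId)
        cache.put(shuffleId, v)
        v
      } finally {
        // Always wake waiters, even if the fetch throws.
        fetching.synchronized {
          fetching -= shuffleId
          fetching.notifyAll()
        }
      }
    }
  }
}
```

Unlike `getStatuses`, the sketch lets an `InterruptedException` from `wait()` propagate rather than swallowing it; the `try`/`finally` still guarantees that waiting threads are woken if `fetch` fails, after which one of them retries.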
http://git-wip-us.apache.org/repos/asf/spark/blob/3476390c/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index ab2255f..932e6c1 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -328,25 +328,14 @@ class DAGScheduler(
     val numTasks = rdd.partitions.length
     val parents = getOrCreateParentStages(rdd, jobId)
     val id = nextStageId.getAndIncrement()
-    val stage = new ShuffleMapStage(id, rdd, numTasks, parents, jobId, rdd.creationSite, shuffleDep)
+    val stage = new ShuffleMapStage(
+      id, rdd, numTasks, parents, jobId, rdd.creationSite, shuffleDep, mapOutputTracker)
 
     stageIdToStage(id) = stage
     shuffleIdToMapStage(shuffleDep.shuffleId) = stage
     updateJobIdStageIdMaps(jobId, stage)
 
-    if (mapOutputTracker.containsShuffle(shuffleDep.shuffleId)) {
-      // A previously run stage generated partitions for this shuffle, so for each output
-      // that's still available, copy information about that output location to the new stage
-      // (so we don't unnecessarily re-compute that data).
-      val serLocs = mapOutputTracker.getSerializedMapOutputStatuses(shuffleDep.shuffleId)
-      val locs = MapOutputTracker.deserializeMapStatuses(serLocs)
-      (0 until locs.length).foreach { i =>
-        if (locs(i) ne null) {
-          // locs(i) will be null if missing
-          stage.addOutputLoc(i, locs(i))
-        }
-      }
-    } else {
+    if (!mapOutputTracker.containsShuffle(shuffleDep.shuffleId)) {
       // Kind of ugly: need to register RDDs with the cache and map output tracker here
       // since we can't do it in the RDD constructor because # of partitions is unknown
       logInfo("Registering RDD " + rdd.id + " (" + rdd.getCreationSite + ")")
@@ -1217,7 +1206,8 @@ class DAGScheduler(
               // The epoch of the task is acceptable (i.e., the task was launched after the most
               // recent failure we're aware of for the executor), so mark the task's output as
               // available.
-              shuffleStage.addOutputLoc(smt.partitionId, status)
+              mapOutputTracker.registerMapOutput(
+                shuffleStage.shuffleDep.shuffleId, smt.partitionId, status)
               // Remove the task's partition from pending partitions. This may have already been
               // done above, but will not have been done yet in cases where the task attempt was
               // from an earlier attempt of the stage (i.e., not the attempt that's currently
@@ -1234,16 +1224,14 @@ class DAGScheduler(
               logInfo("waiting: " + waitingStages)
               logInfo("failed: " + failedStages)
 
-              // We supply true to increment the epoch number here in case this is a
-              // recomputation of the map outputs. In that case, some nodes may have cached
-              // locations with holes (from when we detected the error) and will need the
-              // epoch incremented to refetch them.
-              // TODO: Only increment the epoch number if this is not the first time
-              //       we registered these map outputs.
-              mapOutputTracker.registerMapOutputs(
-                shuffleStage.shuffleDep.shuffleId,
-                shuffleStage.outputLocInMapOutputTrackerFormat(),
-                changeEpoch = true)
+              // This call to increment the epoch may not be strictly necessary, but it is retained
+              // for now in order to minimize the changes in behavior from an earlier version of the
+              // code. This existing behavior of always incrementing the epoch following any
+              // successful shuffle map stage completion may have benefits by causing unneeded
+              // cached map outputs to be cleaned up earlier on executors. In the future we can
+              // consider removing this call, but this will require some extra investigation.
+              // See https://github.com/apache/spark/pull/17955/files#r117385673 for more details.
+              mapOutputTracker.incrementEpoch()
 
               clearCacheLocs()
 
@@ -1343,7 +1331,6 @@ class DAGScheduler(
           }
           // Mark the map whose fetch failed as broken in the map stage
           if (mapId != -1) {
-            mapStage.removeOutputLoc(mapId, bmAddress)
             mapOutputTracker.unregisterMapOutput(shuffleId, mapId, bmAddress)
           }
 
@@ -1393,17 +1380,7 @@ class DAGScheduler(
 
       if (filesLost || !env.blockManager.externalShuffleServiceEnabled) {
         logInfo("Shuffle files lost for executor: %s (epoch %d)".format(execId, currentEpoch))
-        // TODO: This will be really slow if we keep accumulating shuffle map stages
-        for ((shuffleId, stage) <- shuffleIdToMapStage) {
-          stage.removeOutputsOnExecutor(execId)
-          mapOutputTracker.registerMapOutputs(
-            shuffleId,
-            stage.outputLocInMapOutputTrackerFormat(),
-            changeEpoch = true)
-        }
-        if (shuffleIdToMapStage.isEmpty) {
-          mapOutputTracker.incrementEpoch()
-        }
+        mapOutputTracker.removeOutputsOnExecutor(execId)
         clearCacheLocs()
       }
     } else {

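In `DAGScheduler`, both failure paths now go straight to the tracker: a fetch failure calls `unregisterMapOutput`, and a lost executor calls `removeOutputsOnExecutor` instead of looping over every `ShuffleMapStage`. The sketch below models a tracker that holds the only copy of the locations, using hypothetical `Loc` and `SketchTracker` types. As an assumption of the sketch, both removal paths bump the epoch: the removed scheduler code always incremented it (via `changeEpoch = true` or `incrementEpoch()`), and the updated `MapOutputTrackerSuite` asserts that `unregisterMapOutput` raises it.

```scala
import scala.collection.mutable

// Hypothetical location of one map output (stand-in for BlockManagerId).
final case class Loc(executorId: String, host: String)

// Hypothetical driver-side tracker holding the only copy of output locations.
final class SketchTracker {
  // shuffleId -> one location per map partition (null = missing)
  private val outputs = mutable.Map.empty[Int, Array[Loc]]
  private var epoch = 0L

  def getEpoch: Long = synchronized { epoch }

  def registerShuffle(shuffleId: Int, numMaps: Int): Unit = synchronized {
    outputs(shuffleId) = new Array[Loc](numMaps)
  }

  def registerMapOutput(shuffleId: Int, mapId: Int, loc: Loc): Unit = synchronized {
    outputs(shuffleId)(mapId) = loc
  }

  // Fetch-failure path: forget one output, but only if it still points at the
  // location that failed; bump the epoch so executors drop stale caches.
  def unregisterMapOutput(shuffleId: Int, mapId: Int, loc: Loc): Unit = synchronized {
    val locs = outputs(shuffleId)
    if (locs(mapId) == loc) {
      locs(mapId) = null
      epoch += 1
    }
  }

  // Executor-loss path: a single pass over every registered shuffle.
  def removeOutputsOnExecutor(execId: String): Unit = synchronized {
    for (locs <- outputs.values; i <- locs.indices) {
      if (locs(i) != null && locs(i).executorId == execId) locs(i) = null
    }
    epoch += 1
  }

  def findMissingPartitions(shuffleId: Int): Option[Seq[Int]] = synchronized {
    outputs.get(shuffleId).map(locs => locs.indices.filter(i => locs(i) == null))
  }
}
```

Because no stage mirrors these arrays, there is nothing to keep in sync: after `removeOutputsOnExecutor`, `findMissingPartitions` reflects the loss immediately.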
http://git-wip-us.apache.org/repos/asf/spark/blob/3476390c/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala
b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala
index db4d9ef..05f650f 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala
@@ -19,9 +19,8 @@ package org.apache.spark.scheduler
 
 import scala.collection.mutable.HashSet
 
-import org.apache.spark.ShuffleDependency
+import org.apache.spark.{MapOutputTrackerMaster, ShuffleDependency, SparkEnv}
 import org.apache.spark.rdd.RDD
-import org.apache.spark.storage.BlockManagerId
 import org.apache.spark.util.CallSite
 
 /**
@@ -42,13 +41,12 @@ private[spark] class ShuffleMapStage(
     parents: List[Stage],
     firstJobId: Int,
     callSite: CallSite,
-    val shuffleDep: ShuffleDependency[_, _, _])
+    val shuffleDep: ShuffleDependency[_, _, _],
+    mapOutputTrackerMaster: MapOutputTrackerMaster)
   extends Stage(id, rdd, numTasks, parents, firstJobId, callSite) {
 
   private[this] var _mapStageJobs: List[ActiveJob] = Nil
 
-  private[this] var _numAvailableOutputs: Int = 0
-
   /**
    * Partitions that either haven't yet been computed, or that were computed on an executor
    * that has since been lost, so should be re-computed.  This variable is used by the
@@ -60,13 +58,6 @@ private[spark] class ShuffleMapStage(
    */
   val pendingPartitions = new HashSet[Int]
 
-  /**
-   * List of [[MapStatus]] for each partition. The index of the array is the map partition id,
-   * and each value in the array is the list of possible [[MapStatus]] for a partition
-   * (a single task might run multiple times).
-   */
-  private[this] val outputLocs = Array.fill[List[MapStatus]](numPartitions)(Nil)
-
   override def toString: String = "ShuffleMapStage " + id
 
   /**
@@ -88,69 +79,18 @@ private[spark] class ShuffleMapStage(
   /**
    * Number of partitions that have shuffle outputs.
    * When this reaches [[numPartitions]], this map stage is ready.
-   * This should be kept consistent as `outputLocs.filter(!_.isEmpty).size`.
    */
-  def numAvailableOutputs: Int = _numAvailableOutputs
+  def numAvailableOutputs: Int = mapOutputTrackerMaster.getNumAvailableOutputs(shuffleDep.shuffleId)
 
   /**
    * Returns true if the map stage is ready, i.e. all partitions have shuffle outputs.
-   * This should be the same as `outputLocs.contains(Nil)`.
    */
-  def isAvailable: Boolean = _numAvailableOutputs == numPartitions
+  def isAvailable: Boolean = numAvailableOutputs == numPartitions
 
   /** Returns the sequence of partition ids that are missing (i.e. needs to be computed). */
   override def findMissingPartitions(): Seq[Int] = {
-    val missing = (0 until numPartitions).filter(id => outputLocs(id).isEmpty)
-    assert(missing.size == numPartitions - _numAvailableOutputs,
-      s"${missing.size} missing, expected ${numPartitions - _numAvailableOutputs}")
-    missing
-  }
-
-  def addOutputLoc(partition: Int, status: MapStatus): Unit = {
-    val prevList = outputLocs(partition)
-    outputLocs(partition) = status :: prevList
-    if (prevList == Nil) {
-      _numAvailableOutputs += 1
-    }
-  }
-
-  def removeOutputLoc(partition: Int, bmAddress: BlockManagerId): Unit = {
-    val prevList = outputLocs(partition)
-    val newList = prevList.filterNot(_.location == bmAddress)
-    outputLocs(partition) = newList
-    if (prevList != Nil && newList == Nil) {
-      _numAvailableOutputs -= 1
-    }
-  }
-
-  /**
-   * Returns an array of [[MapStatus]] (index by partition id). For each partition, the returned
-   * value contains only one (i.e. the first) [[MapStatus]]. If there is no entry for the partition,
-   * that position is filled with null.
-   */
-  def outputLocInMapOutputTrackerFormat(): Array[MapStatus] = {
-    outputLocs.map(_.headOption.orNull)
-  }
-
-  /**
-   * Removes all shuffle outputs associated with this executor. Note that this will also remove
-   * outputs which are served by an external shuffle server (if one exists), as they are still
-   * registered with this execId.
-   */
-  def removeOutputsOnExecutor(execId: String): Unit = {
-    var becameUnavailable = false
-    for (partition <- 0 until numPartitions) {
-      val prevList = outputLocs(partition)
-      val newList = prevList.filterNot(_.location.executorId == execId)
-      outputLocs(partition) = newList
-      if (prevList != Nil && newList == Nil) {
-        becameUnavailable = true
-        _numAvailableOutputs -= 1
-      }
-    }
-    if (becameUnavailable) {
-      logInfo("%s is now unavailable on executor %s (%d/%d, %s)".format(
-        this, execId, _numAvailableOutputs, numPartitions, isAvailable))
-    }
+    mapOutputTrackerMaster
+      .findMissingPartitions(shuffleDep.shuffleId)
+      .getOrElse(0 until numPartitions)
   }
 }

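With `outputLocs` and `_numAvailableOutputs` gone, `ShuffleMapStage` derives `numAvailableOutputs`, `isAvailable`, and `findMissingPartitions` from the tracker on every call, so the stage and the tracker cannot disagree. The sketch below shows that delegation; `OutputSource` is a hypothetical trait exposing only the two `MapOutputTrackerMaster` methods the stage uses, and `SketchMapStage` is likewise illustrative.

```scala
// Hypothetical view of the two MapOutputTrackerMaster queries the stage uses.
trait OutputSource {
  def getNumAvailableOutputs(shuffleId: Int): Int
  def findMissingPartitions(shuffleId: Int): Option[Seq[Int]]
}

// Simplified stage: no outputLocs array and no counter of its own.
final class SketchMapStage(shuffleId: Int, numPartitions: Int, tracker: OutputSource) {
  def numAvailableOutputs: Int = tracker.getNumAvailableOutputs(shuffleId)

  def isAvailable: Boolean = numAvailableOutputs == numPartitions

  // Unknown shuffle => every partition must be computed.
  def findMissingPartitions(): Seq[Int] =
    tracker.findMissingPartitions(shuffleId).getOrElse(0 until numPartitions)
}
```

When the tracker has no entry for the shuffle (`None`), the stage treats every partition as missing, mirroring `.getOrElse(0 until numPartitions)` in the diff.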
http://git-wip-us.apache.org/repos/asf/spark/blob/3476390c/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
index f3033e2..629cfc7 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
@@ -129,7 +129,7 @@ private[spark] class TaskSchedulerImpl private[scheduler](
 
   var backend: SchedulerBackend = null
 
-  val mapOutputTracker = SparkEnv.get.mapOutputTracker
+  val mapOutputTracker = SparkEnv.get.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster]
 
   private var schedulableBuilder: SchedulableBuilder = null
   // default scheduler is FIFO

http://git-wip-us.apache.org/repos/asf/spark/blob/3476390c/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
index 4fe5c5e..bc3d23e 100644
--- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
@@ -139,21 +139,21 @@ class MapOutputTrackerSuite extends SparkFunSuite {
       slaveRpcEnv.setupEndpointRef(rpcEnv.address, MapOutputTracker.ENDPOINT_NAME)
 
     masterTracker.registerShuffle(10, 1)
-    masterTracker.incrementEpoch()
     slaveTracker.updateEpoch(masterTracker.getEpoch)
+    // This is expected to fail because no outputs have been registered for the shuffle.
     intercept[FetchFailedException] { slaveTracker.getMapSizesByExecutorId(10, 0) }
 
     val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L))
     masterTracker.registerMapOutput(10, 0, MapStatus(
       BlockManagerId("a", "hostA", 1000), Array(1000L)))
-    masterTracker.incrementEpoch()
     slaveTracker.updateEpoch(masterTracker.getEpoch)
     assert(slaveTracker.getMapSizesByExecutorId(10, 0) ===
       Seq((BlockManagerId("a", "hostA", 1000), ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000)))))
     assert(0 == masterTracker.getNumCachedSerializedBroadcast)
 
+    val masterTrackerEpochBeforeLossOfMapOutput = masterTracker.getEpoch
     masterTracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000))
-    masterTracker.incrementEpoch()
+    assert(masterTracker.getEpoch > masterTrackerEpochBeforeLossOfMapOutput)
     slaveTracker.updateEpoch(masterTracker.getEpoch)
     intercept[FetchFailedException] { slaveTracker.getMapSizesByExecutorId(10, 0) }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/3476390c/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
index 622f798..3931d53 100644
--- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
@@ -359,6 +359,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC
     val shuffleMapRdd = new MyRDD(sc, 1, Nil)
     val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(1))
     val shuffleHandle = manager.registerShuffle(0, 1, shuffleDep)
+    mapTrackerMaster.registerShuffle(0, 1)
 
     // first attempt -- its successful
     val writer1 = manager.getWriter[Int, Int](shuffleHandle, 0,
@@ -393,7 +394,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC
 
     // register one of the map outputs -- doesn't matter which one
     mapOutput1.foreach { case mapStatus =>
-      mapTrackerMaster.registerMapOutputs(0, Array(mapStatus))
+      mapTrackerMaster.registerMapOutput(0, 0, mapStatus)
     }
 
     val reader = manager.getReader[Int, Int](shuffleHandle, 0, 1,

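The updated `ShuffleSuite` now calls `registerShuffle(0, 1)` before `registerMapOutput(0, 0, mapStatus)`. Presumably this is because per-output registration writes into per-shuffle state that must already exist; `getStatistics` in this diff likewise looks the shuffle up with `shuffleStatuses(dep.shuffleId)`, which, for a Scala `Map`, throws for an unknown key. The sketch below illustrates that ordering constraint with a hypothetical `Registry` backed by `scala.collection.mutable.Map`.

```scala
import scala.collection.mutable

// Hypothetical registry: per-shuffle slots must be created before outputs are added.
final class Registry {
  private val slots = mutable.Map.empty[Int, Array[String]]

  def registerShuffle(shuffleId: Int, numMaps: Int): Unit =
    slots(shuffleId) = new Array[String](numMaps)

  // slots(shuffleId) uses Map.apply, which throws NoSuchElementException
  // if registerShuffle was never called for this shuffle.
  def registerMapOutput(shuffleId: Int, mapId: Int, location: String): Unit =
    slots(shuffleId)(mapId) = location

  def location(shuffleId: Int, mapId: Int): Option[String] =
    slots.get(shuffleId).flatMap(a => Option(a(mapId)))
}
```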
http://git-wip-us.apache.org/repos/asf/spark/blob/3476390c/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala
b/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala
index 2b18ebe..571c6bb 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala
@@ -86,7 +86,8 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M
     sc = new SparkContext(conf)
     val scheduler = mock[TaskSchedulerImpl]
     when(scheduler.sc).thenReturn(sc)
-    when(scheduler.mapOutputTracker).thenReturn(SparkEnv.get.mapOutputTracker)
+    when(scheduler.mapOutputTracker).thenReturn(
+      SparkEnv.get.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster])
     scheduler
   }
 

