Repository: spark
Updated Branches:
  refs/heads/branch-2.2 af41dedc6 -> 3158fc3ce


[SPARK-23243][SPARK-20715][CORE][2.2] Fix RDD.repartition() data correctness 
issue

## What changes were proposed in this pull request?

Backport of #22354 and #17955 to 2.2 (#22354 depends on methods introduced by 
#17955).

-------

An alternative fix for #21698

When Spark reruns tasks for an RDD, there are 3 different behaviors (sketched 
below):
1. determinate: always returns the same result, in the same order, when rerun.
2. unordered: returns the same data set, but in a random order, when rerun.
3. indeterminate: may return a different result when rerun.
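
For illustration only, a minimal sketch of RDDs that fall into each category 
(assumes `sc` is an existing `SparkContext`; this code is not part of the patch):

```scala
// Illustrative sketch only; `sc` is assumed to be an existing SparkContext.
val base = sc.parallelize(1 to 100, 4)

// 1. determinate: same data, same order on every rerun.
val determinate = base.map(_ * 2)

// 2. unordered: same data set, but reducers fetch shuffle blocks in a random order.
val unordered = base.map(x => (x % 10, x)).reduceByKey(_ + _)

// 3. indeterminate: repartition's round-robin ids depend on the order in which
//    records arrive, so a rerun can assign records to different reducers.
val indeterminate = base.repartition(4)
```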

Normally Spark doesn't need to care about this. Spark runs stages one by one; 
when a task fails, it is simply rerun. Although the rerun task may return a 
different result, users will not be surprised.

However, Spark may rerun a finished stage when it sees fetch failures. When this 
happens, Spark needs to rerun all the tasks of all the succeeding stages if the 
RDD output is indeterminate, because the input of the succeeding stages has 
changed.

If the RDD output is determinate, we only need to rerun the failed tasks of the 
succeeding stages, because the input doesn't change.

If the RDD output is unordered, it's the same as the determinate case, because 
the shuffle partitioner is always deterministic (a round-robin partitioner is 
not a shuffle partitioner that extends `org.apache.spark.Partitioner`), so the 
reducers will still get the same input data set.
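
For example, the contract documented on `Partitioner` by this patch (see the 
`Partitioner.scala` change below) means a given key always maps to the same 
reduce partition:

```scala
// A shuffle partitioner must be deterministic: the same key yields the same partition id.
val partitioner = new org.apache.spark.HashPartitioner(8)
assert(partitioner.getPartition("some-key") == partitioner.getPartition("some-key"))
```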

This PR fixes the failure handling for `repartition` to avoid correctness 
issues.

`repartition` applies a stateful map function to generate a round-robin id, 
which is order-sensitive and makes the RDD's output indeterminate. When a stage 
containing `repartition` is rerun, we must also rerun all the tasks of all the 
succeeding stages.
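
For reference, a simplified sketch of that round-robin key generation, adapted 
from `RDD.coalesce` with `shuffle = true` (details simplified; the method name 
and signature here are illustrative):

```scala
import scala.util.Random

// Simplified sketch: `position` depends on the order in which records are consumed,
// so the reducer each record is sent to changes if the input order changes.
def distributePartition[T](numPartitions: Int)(index: Int, items: Iterator[T]): Iterator[(Int, T)] = {
  var position = new Random(index).nextInt(numPartitions)
  items.map { t =>
    position += 1
    (position, t)
  }
}
```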

**Future improvements:**
1. Currently we can't roll back and rerun a shuffle map stage; we just fail the 
job. We should fix this later: https://issues.apache.org/jira/browse/SPARK-25341
2. Currently we can't roll back and rerun a result stage; we just fail the job 
(a checkpointing workaround is sketched below). We should fix this later: 
https://issues.apache.org/jira/browse/SPARK-25342
3. We should provide a public API that allows users to tag the deterministic 
level of an RDD's computing function.
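
Until such rollback is supported, a job that hits this situation is aborted 
with a message suggesting that the user checkpoint the RDD before `repartition`. 
A hedged sketch of that workaround (the checkpoint directory, `inputRdd`, and 
`parseRecord` are placeholders, not part of the patch):

```scala
// Reliably checkpointed RDDs are treated as DETERMINATE, so the input to repartition
// no longer changes across stage retries. `inputRdd` and `parseRecord` are hypothetical.
sc.setCheckpointDir("hdfs:///tmp/spark-checkpoints")
val stable = inputRdd.map(parseRecord)
stable.checkpoint()
stable.count()                      // force materialization of the checkpoint
val repartitioned = stable.repartition(200)
```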

## How was this patch tested?

A new test case (see the `DAGSchedulerSuite` changes below).

Closes #22382 from bersprockets/SPARK-23243-2.2.

Lead-authored-by: Bruce Robbins <bersprock...@gmail.com>
Co-authored-by: Josh Rosen <joshro...@databricks.com>
Co-authored-by: Wenchen Fan <wenc...@databricks.com>
Signed-off-by: Wenchen Fan <wenc...@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/3158fc3c
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/3158fc3c
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/3158fc3c

Branch: refs/heads/branch-2.2
Commit: 3158fc3ce390f96d8c65d70bcdf9ac9aa26be24b
Parents: af41ded
Author: Bruce Robbins <bersprock...@gmail.com>
Authored: Tue Sep 11 12:06:19 2018 +0800
Committer: Wenchen Fan <wenc...@databricks.com>
Committed: Tue Sep 11 12:06:19 2018 +0800

----------------------------------------------------------------------
 .../org/apache/spark/MapOutputTracker.scala     | 636 +++++++++++--------
 .../scala/org/apache/spark/Partitioner.scala    |   3 +
 .../org/apache/spark/executor/Executor.scala    |  10 +-
 .../org/apache/spark/rdd/MapPartitionsRDD.scala |  21 +-
 .../main/scala/org/apache/spark/rdd/RDD.scala   | 100 ++-
 .../apache/spark/scheduler/DAGScheduler.scala   | 110 ++--
 .../spark/scheduler/ShuffleMapStage.scala       |  76 +--
 .../spark/scheduler/TaskSchedulerImpl.scala     |   2 +-
 .../apache/spark/MapOutputTrackerSuite.scala    |   6 +-
 .../scala/org/apache/spark/ShuffleSuite.scala   |   3 +-
 .../spark/scheduler/BlacklistTrackerSuite.scala |   3 +-
 .../spark/scheduler/DAGSchedulerSuite.scala     | 169 ++++-
 .../execution/exchange/ShuffleExchange.scala    |  17 +-
 13 files changed, 750 insertions(+), 406 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/3158fc3c/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala 
b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
index 4ef6656..3e10b9e 100644
--- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -34,6 +34,156 @@ import org.apache.spark.shuffle.MetadataFetchFailedException
 import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockId}
 import org.apache.spark.util._
 
+/**
+ * Helper class used by the [[MapOutputTrackerMaster]] to perform bookkeeping 
for a single
+ * ShuffleMapStage.
+ *
+ * This class maintains a mapping from mapIds to `MapStatus`. It also 
maintains a cache of
+ * serialized map statuses in order to speed up tasks' requests for map output 
statuses.
+ *
+ * All public methods of this class are thread-safe.
+ */
+private class ShuffleStatus(numPartitions: Int) {
+
+  // All accesses to the following state must be guarded with 
`this.synchronized`.
+
+  /**
+   * MapStatus for each partition. The index of the array is the map partition 
id.
+   * Each value in the array is the MapStatus for a partition, or null if the 
partition
+   * is not available. Even though in theory a task may run multiple times 
(due to speculation,
+   * stage retries, etc.), in practice the likelihood of a map output being 
available at multiple
+   * locations is so small that we choose to ignore that case and store only a 
single location
+   * for each output.
+   */
+  private[this] val mapStatuses = new Array[MapStatus](numPartitions)
+
+  /**
+   * The cached result of serializing the map statuses array. This cache is 
lazily populated when
+   * [[serializedMapStatus]] is called. The cache is invalidated when map 
outputs are removed.
+   */
+  private[this] var cachedSerializedMapStatus: Array[Byte] = _
+
+  /**
+   * Broadcast variable holding serialized map output statuses array. When 
[[serializedMapStatus]]
+   * serializes the map statuses array it may detect that the result is too 
large to send in a
+   * single RPC, in which case it places the serialized array into a broadcast 
variable and then
+   * sends a serialized broadcast variable instead. This variable holds a 
reference to that
+   * broadcast variable in order to keep it from being garbage collected and 
to allow for it to be
+   * explicitly destroyed later on when the ShuffleMapStage is 
garbage-collected.
+   */
+  private[this] var cachedSerializedBroadcast: Broadcast[Array[Byte]] = _
+
+  /**
+   * Counter tracking the number of partitions that have output. This is a 
performance optimization
+   * to avoid having to count the number of non-null entries in the 
`mapStatuses` array and should
+   * be equivalent to`mapStatuses.count(_ ne null)`.
+   */
+  private[this] var _numAvailableOutputs: Int = 0
+
+  /**
+   * Register a map output. If there is already a registered location for the 
map output then it
+   * will be replaced by the new location.
+   */
+  def addMapOutput(mapId: Int, status: MapStatus): Unit = synchronized {
+    if (mapStatuses(mapId) == null) {
+      _numAvailableOutputs += 1
+      invalidateSerializedMapOutputStatusCache()
+    }
+    mapStatuses(mapId) = status
+  }
+
+  /**
+   * Remove the map output which was served by the specified block manager.
+   * This is a no-op if there is no registered map output or if the registered 
output is from a
+   * different block manager.
+   */
+  def removeMapOutput(mapId: Int, bmAddress: BlockManagerId): Unit = 
synchronized {
+    if (mapStatuses(mapId) != null && mapStatuses(mapId).location == 
bmAddress) {
+      _numAvailableOutputs -= 1
+      mapStatuses(mapId) = null
+      invalidateSerializedMapOutputStatusCache()
+    }
+  }
+
+  /**
+   * Removes all map outputs associated with the specified executor. Note that 
this will also
+   * remove outputs which are served by an external shuffle server (if one 
exists), as they are
+   * still registered with that execId.
+   */
+  def removeOutputsOnExecutor(execId: String): Unit = synchronized {
+    for (mapId <- 0 until mapStatuses.length) {
+      if (mapStatuses(mapId) != null && mapStatuses(mapId).location.executorId 
== execId) {
+        _numAvailableOutputs -= 1
+        mapStatuses(mapId) = null
+        invalidateSerializedMapOutputStatusCache()
+      }
+    }
+  }
+
+  /**
+   * Number of partitions that have shuffle outputs.
+   */
+  def numAvailableOutputs: Int = synchronized {
+    _numAvailableOutputs
+  }
+
+  /**
+   * Returns the sequence of partition ids that are missing (i.e. needs to be 
computed).
+   */
+  def findMissingPartitions(): Seq[Int] = synchronized {
+    val missing = (0 until numPartitions).filter(id => mapStatuses(id) == null)
+    assert(missing.size == numPartitions - _numAvailableOutputs,
+      s"${missing.size} missing, expected ${numPartitions - 
_numAvailableOutputs}")
+    missing
+  }
+
+  /**
+   * Serializes the mapStatuses array into an efficient compressed format. See 
the comments on
+   * `MapOutputTracker.serializeMapStatuses()` for more details on the 
serialization format.
+   *
+   * This method is designed to be called multiple times and implements 
caching in order to speed
+   * up subsequent requests. If the cache is empty and multiple threads 
concurrently attempt to
+   * serialize the map statuses then serialization will only be performed in a 
single thread and all
+   * other threads will block until the cache is populated.
+   */
+  def serializedMapStatus(
+      broadcastManager: BroadcastManager,
+      isLocal: Boolean,
+      minBroadcastSize: Int): Array[Byte] = synchronized {
+    if (cachedSerializedMapStatus eq null) {
+      val serResult = MapOutputTracker.serializeMapStatuses(
+          mapStatuses, broadcastManager, isLocal, minBroadcastSize)
+      cachedSerializedMapStatus = serResult._1
+      cachedSerializedBroadcast = serResult._2
+    }
+    cachedSerializedMapStatus
+  }
+
+  // Used in testing.
+  def hasCachedSerializedBroadcast: Boolean = synchronized {
+    cachedSerializedBroadcast != null
+  }
+
+  /**
+   * Helper function which provides thread-safe access to the mapStatuses 
array.
+   * The function should NOT mutate the array.
+   */
+  def withMapStatuses[T](f: Array[MapStatus] => T): T = synchronized {
+    f(mapStatuses)
+  }
+
+  /**
+   * Clears the cached serialized map output statuses.
+   */
+  def invalidateSerializedMapOutputStatusCache(): Unit = synchronized {
+    if (cachedSerializedBroadcast != null) {
+      cachedSerializedBroadcast.destroy()
+      cachedSerializedBroadcast = null
+    }
+    cachedSerializedMapStatus = null
+  }
+}
+
 private[spark] sealed trait MapOutputTrackerMessage
 private[spark] case class GetMapOutputStatuses(shuffleId: Int)
   extends MapOutputTrackerMessage
@@ -62,37 +212,26 @@ private[spark] class MapOutputTrackerMasterEndpoint(
 }
 
 /**
- * Class that keeps track of the location of the map output of
- * a stage. This is abstract because different versions of MapOutputTracker
- * (driver and executor) use different HashMap to store its metadata.
- */
+ * Class that keeps track of the location of the map output of a stage. This 
is abstract because the
+ * driver and executor have different versions of the MapOutputTracker. In 
principle the driver-
+ * and executor-side classes don't need to share a common base class; the 
current shared base class
+ * is maintained primarily for backwards-compatibility in order to avoid 
having to update existing
+ * test code.
+*/
 private[spark] abstract class MapOutputTracker(conf: SparkConf) extends 
Logging {
-
   /** Set to the MapOutputTrackerMasterEndpoint living on the driver. */
   var trackerEndpoint: RpcEndpointRef = _
 
   /**
-   * This HashMap has different behavior for the driver and the executors.
-   *
-   * On the driver, it serves as the source of map outputs recorded from 
ShuffleMapTasks.
-   * On the executors, it simply serves as a cache, in which a miss triggers a 
fetch from the
-   * driver's corresponding HashMap.
-   *
-   * Note: because mapStatuses is accessed concurrently, subclasses should 
make sure it's a
-   * thread-safe map.
-   */
-  protected val mapStatuses: Map[Int, Array[MapStatus]]
-
-  /**
-   * Incremented every time a fetch fails so that client nodes know to clear
-   * their cache of map output locations if this happens.
+   * The driver-side counter is incremented every time that a map output is 
lost. This value is sent
+   * to executors as part of tasks, where executors compare the new epoch 
number to the highest
+   * epoch number that they received in the past. If the new epoch number is 
higher then executors
+   * will clear their local caches of map output statuses and will re-fetch 
(possibly updated)
+   * statuses from the driver.
    */
   protected var epoch: Long = 0
   protected val epochLock = new AnyRef
 
-  /** Remembers which map output locations are currently being fetched on an 
executor. */
-  private val fetching = new HashSet[Int]
-
   /**
    * Send a message to the trackerEndpoint and get its result within a default 
timeout, or
    * throw a SparkException if this fails.
@@ -116,14 +255,7 @@ private[spark] abstract class MapOutputTracker(conf: 
SparkConf) extends Logging
     }
   }
 
-  /**
-   * Called from executors to get the server URIs and output sizes for each 
shuffle block that
-   * needs to be read from a given reduce task.
-   *
-   * @return A sequence of 2-item tuples, where the first item in the tuple is 
a BlockManagerId,
-   *         and the second item is a sequence of (shuffle block id, shuffle 
block size) tuples
-   *         describing the shuffle blocks that are stored at that block 
manager.
-   */
+  // For testing
   def getMapSizesByExecutorId(shuffleId: Int, reduceId: Int)
       : Seq[(BlockManagerId, Seq[(BlockId, Long)])] = {
     getMapSizesByExecutorId(shuffleId, reduceId, reduceId + 1)
@@ -139,135 +271,31 @@ private[spark] abstract class MapOutputTracker(conf: 
SparkConf) extends Logging
    *         describing the shuffle blocks that are stored at that block 
manager.
    */
   def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, 
endPartition: Int)
-      : Seq[(BlockManagerId, Seq[(BlockId, Long)])] = {
-    logDebug(s"Fetching outputs for shuffle $shuffleId, partitions 
$startPartition-$endPartition")
-    val statuses = getStatuses(shuffleId)
-    // Synchronize on the returned array because, on the driver, it gets 
mutated in place
-    statuses.synchronized {
-      return MapOutputTracker.convertMapStatuses(shuffleId, startPartition, 
endPartition, statuses)
-    }
-  }
+      : Seq[(BlockManagerId, Seq[(BlockId, Long)])]
 
   /**
-   * Return statistics about all of the outputs for a given shuffle.
+   * Deletes map output status information for the specified shuffle stage.
    */
-  def getStatistics(dep: ShuffleDependency[_, _, _]): MapOutputStatistics = {
-    val statuses = getStatuses(dep.shuffleId)
-    // Synchronize on the returned array because, on the driver, it gets 
mutated in place
-    statuses.synchronized {
-      val totalSizes = new Array[Long](dep.partitioner.numPartitions)
-      for (s <- statuses) {
-        for (i <- 0 until totalSizes.length) {
-          totalSizes(i) += s.getSizeForBlock(i)
-        }
-      }
-      new MapOutputStatistics(dep.shuffleId, totalSizes)
-    }
-  }
-
-  /**
-   * Get or fetch the array of MapStatuses for a given shuffle ID. NOTE: 
clients MUST synchronize
-   * on this array when reading it, because on the driver, we may be changing 
it in place.
-   *
-   * (It would be nice to remove this restriction in the future.)
-   */
-  private def getStatuses(shuffleId: Int): Array[MapStatus] = {
-    val statuses = mapStatuses.get(shuffleId).orNull
-    if (statuses == null) {
-      logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching 
them")
-      val startTime = System.currentTimeMillis
-      var fetchedStatuses: Array[MapStatus] = null
-      fetching.synchronized {
-        // Someone else is fetching it; wait for them to be done
-        while (fetching.contains(shuffleId)) {
-          try {
-            fetching.wait()
-          } catch {
-            case e: InterruptedException =>
-          }
-        }
-
-        // Either while we waited the fetch happened successfully, or
-        // someone fetched it in between the get and the fetching.synchronized.
-        fetchedStatuses = mapStatuses.get(shuffleId).orNull
-        if (fetchedStatuses == null) {
-          // We have to do the fetch, get others to wait for us.
-          fetching += shuffleId
-        }
-      }
+  def unregisterShuffle(shuffleId: Int): Unit
 
-      if (fetchedStatuses == null) {
-        // We won the race to fetch the statuses; do so
-        logInfo("Doing the fetch; tracker endpoint = " + trackerEndpoint)
-        // This try-finally prevents hangs due to timeouts:
-        try {
-          val fetchedBytes = 
askTracker[Array[Byte]](GetMapOutputStatuses(shuffleId))
-          fetchedStatuses = 
MapOutputTracker.deserializeMapStatuses(fetchedBytes)
-          logInfo("Got the output locations")
-          mapStatuses.put(shuffleId, fetchedStatuses)
-        } finally {
-          fetching.synchronized {
-            fetching -= shuffleId
-            fetching.notifyAll()
-          }
-        }
-      }
-      logDebug(s"Fetching map output statuses for shuffle $shuffleId took " +
-        s"${System.currentTimeMillis - startTime} ms")
-
-      if (fetchedStatuses != null) {
-        return fetchedStatuses
-      } else {
-        logError("Missing all output locations for shuffle " + shuffleId)
-        throw new MetadataFetchFailedException(
-          shuffleId, -1, "Missing all output locations for shuffle " + 
shuffleId)
-      }
-    } else {
-      return statuses
-    }
-  }
-
-  /** Called to get current epoch number. */
-  def getEpoch: Long = {
-    epochLock.synchronized {
-      return epoch
-    }
-  }
-
-  /**
-   * Called from executors to update the epoch number, potentially clearing 
old outputs
-   * because of a fetch failure. Each executor task calls this with the latest 
epoch
-   * number on the driver at the time it was created.
-   */
-  def updateEpoch(newEpoch: Long) {
-    epochLock.synchronized {
-      if (newEpoch > epoch) {
-        logInfo("Updating epoch to " + newEpoch + " and clearing cache")
-        epoch = newEpoch
-        mapStatuses.clear()
-      }
-    }
-  }
-
-  /** Unregister shuffle data. */
-  def unregisterShuffle(shuffleId: Int) {
-    mapStatuses.remove(shuffleId)
-  }
-
-  /** Stop the tracker. */
-  def stop() { }
+  def stop() {}
 }
 
 /**
- * MapOutputTracker for the driver.
+ * Driver-side class that keeps track of the location of the map output of a 
stage.
+ *
+ * The DAGScheduler uses this class to (de)register map output statuses and to 
look up statistics
+ * for performing locality-aware reduce task scheduling.
+ *
+ * ShuffleMapStage uses this class for tracking available / missing outputs in 
order to determine
+ * which tasks need to be run.
  */
-private[spark] class MapOutputTrackerMaster(conf: SparkConf,
-    broadcastManager: BroadcastManager, isLocal: Boolean)
+private[spark] class MapOutputTrackerMaster(
+    conf: SparkConf,
+    broadcastManager: BroadcastManager,
+    isLocal: Boolean)
   extends MapOutputTracker(conf) {
 
-  /** Cache a serialized version of the output statuses for each shuffle to 
send them out faster */
-  private var cacheEpoch = epoch
-
   // The size at which we use Broadcast to send the map output statuses to the 
executors
   private val minSizeForBroadcast =
     conf.getSizeAsBytes("spark.shuffle.mapOutput.minSizeForBroadcast", 
"512k").toInt
@@ -287,22 +315,12 @@ private[spark] class MapOutputTrackerMaster(conf: 
SparkConf,
   // can be read locally, but may lead to more delay in scheduling if those 
locations are busy.
   private val REDUCER_PREF_LOCS_FRACTION = 0.2
 
-  // HashMaps for storing mapStatuses and cached serialized statuses in the 
driver.
+  // HashMap for storing shuffleStatuses in the driver.
   // Statuses are dropped only by explicit de-registering.
-  protected val mapStatuses = new ConcurrentHashMap[Int, 
Array[MapStatus]]().asScala
-  private val cachedSerializedStatuses = new ConcurrentHashMap[Int, 
Array[Byte]]().asScala
+  private val shuffleStatuses = new ConcurrentHashMap[Int, 
ShuffleStatus]().asScala
 
   private val maxRpcMessageSize = RpcUtils.maxMessageSizeBytes(conf)
 
-  // Kept in sync with cachedSerializedStatuses explicitly
-  // This is required so that the Broadcast variable remains in scope until we 
remove
-  // the shuffleId explicitly or implicitly.
-  private val cachedSerializedBroadcast = new HashMap[Int, 
Broadcast[Array[Byte]]]()
-
-  // This is to prevent multiple serializations of the same shuffle - which 
happens when
-  // there is a request storm when shuffle start.
-  private val shuffleIdLocks = new ConcurrentHashMap[Int, AnyRef]()
-
   // requests for map output statuses
   private val mapOutputRequests = new LinkedBlockingQueue[GetMapOutputMessage]
 
@@ -348,8 +366,9 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf,
             val hostPort = context.senderAddress.hostPort
             logDebug("Handling request to send map output locations for 
shuffle " + shuffleId +
               " to " + hostPort)
-            val mapOutputStatuses = getSerializedMapOutputStatuses(shuffleId)
-            context.reply(mapOutputStatuses)
+            val shuffleStatus = shuffleStatuses.get(shuffleId).head
+            context.reply(
+              shuffleStatus.serializedMapStatus(broadcastManager, isLocal, 
minSizeForBroadcast))
           } catch {
             case NonFatal(e) => logError(e.getMessage, e)
           }
@@ -363,59 +382,77 @@ private[spark] class MapOutputTrackerMaster(conf: 
SparkConf,
   /** A poison endpoint that indicates MessageLoop should exit its message 
loop. */
   private val PoisonPill = new GetMapOutputMessage(-99, null)
 
-  // Exposed for testing
-  private[spark] def getNumCachedSerializedBroadcast = 
cachedSerializedBroadcast.size
+  // Used only in unit tests.
+  private[spark] def getNumCachedSerializedBroadcast: Int = {
+    shuffleStatuses.valuesIterator.count(_.hasCachedSerializedBroadcast)
+  }
 
   def registerShuffle(shuffleId: Int, numMaps: Int) {
-    if (mapStatuses.put(shuffleId, new Array[MapStatus](numMaps)).isDefined) {
+    if (shuffleStatuses.put(shuffleId, new ShuffleStatus(numMaps)).isDefined) {
       throw new IllegalArgumentException("Shuffle ID " + shuffleId + " 
registered twice")
     }
-    // add in advance
-    shuffleIdLocks.putIfAbsent(shuffleId, new Object())
   }
 
   def registerMapOutput(shuffleId: Int, mapId: Int, status: MapStatus) {
-    val array = mapStatuses(shuffleId)
-    array.synchronized {
-      array(mapId) = status
-    }
-  }
-
-  /** Register multiple map output information for the given shuffle */
-  def registerMapOutputs(shuffleId: Int, statuses: Array[MapStatus], 
changeEpoch: Boolean = false) {
-    mapStatuses.put(shuffleId, statuses.clone())
-    if (changeEpoch) {
-      incrementEpoch()
-    }
+    shuffleStatuses(shuffleId).addMapOutput(mapId, status)
   }
 
   /** Unregister map output information of the given shuffle, mapper and block 
manager */
   def unregisterMapOutput(shuffleId: Int, mapId: Int, bmAddress: 
BlockManagerId) {
-    val arrayOpt = mapStatuses.get(shuffleId)
-    if (arrayOpt.isDefined && arrayOpt.get != null) {
-      val array = arrayOpt.get
-      array.synchronized {
-        if (array(mapId) != null && array(mapId).location == bmAddress) {
-          array(mapId) = null
-        }
-      }
-      incrementEpoch()
-    } else {
-      throw new SparkException("unregisterMapOutput called for nonexistent 
shuffle ID")
+    shuffleStatuses.get(shuffleId) match {
+      case Some(shuffleStatus) =>
+        shuffleStatus.removeMapOutput(mapId, bmAddress)
+        incrementEpoch()
+      case None =>
+        throw new SparkException("unregisterMapOutput called for nonexistent 
shuffle ID")
     }
   }
 
   /** Unregister shuffle data */
-  override def unregisterShuffle(shuffleId: Int) {
-    mapStatuses.remove(shuffleId)
-    cachedSerializedStatuses.remove(shuffleId)
-    cachedSerializedBroadcast.remove(shuffleId).foreach(v => 
removeBroadcast(v))
-    shuffleIdLocks.remove(shuffleId)
+  def unregisterShuffle(shuffleId: Int) {
+    shuffleStatuses.remove(shuffleId).foreach { shuffleStatus =>
+      shuffleStatus.invalidateSerializedMapOutputStatusCache()
+    }
+  }
+
+  /**
+   * Removes all shuffle outputs associated with this executor. Note that this 
will also remove
+   * outputs which are served by an external shuffle server (if one exists), 
as they are still
+   * registered with this execId.
+   */
+  def removeOutputsOnExecutor(execId: String): Unit = {
+    shuffleStatuses.valuesIterator.foreach { _.removeOutputsOnExecutor(execId) 
}
+    incrementEpoch()
   }
 
   /** Check if the given shuffle is being tracked */
-  def containsShuffle(shuffleId: Int): Boolean = {
-    cachedSerializedStatuses.contains(shuffleId) || 
mapStatuses.contains(shuffleId)
+  def containsShuffle(shuffleId: Int): Boolean = 
shuffleStatuses.contains(shuffleId)
+
+  def getNumAvailableOutputs(shuffleId: Int): Int = {
+    shuffleStatuses.get(shuffleId).map(_.numAvailableOutputs).getOrElse(0)
+  }
+
+  /**
+   * Returns the sequence of partition ids that are missing (i.e. needs to be 
computed), or None
+   * if the MapOutputTrackerMaster doesn't know about this shuffle.
+   */
+  def findMissingPartitions(shuffleId: Int): Option[Seq[Int]] = {
+    shuffleStatuses.get(shuffleId).map(_.findMissingPartitions())
+  }
+
+  /**
+   * Return statistics about all of the outputs for a given shuffle.
+   */
+  def getStatistics(dep: ShuffleDependency[_, _, _]): MapOutputStatistics = {
+    shuffleStatuses(dep.shuffleId).withMapStatuses { statuses =>
+      val totalSizes = new Array[Long](dep.partitioner.numPartitions)
+      for (s <- statuses) {
+        for (i <- 0 until totalSizes.length) {
+          totalSizes(i) += s.getSizeForBlock(i)
+        }
+      }
+      new MapOutputStatistics(dep.shuffleId, totalSizes)
+    }
   }
 
   /**
@@ -459,9 +496,9 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf,
       fractionThreshold: Double)
     : Option[Array[BlockManagerId]] = {
 
-    val statuses = mapStatuses.get(shuffleId).orNull
-    if (statuses != null) {
-      statuses.synchronized {
+    val shuffleStatus = shuffleStatuses.get(shuffleId).orNull
+    if (shuffleStatus != null) {
+      shuffleStatus.withMapStatuses { statuses =>
         if (statuses.nonEmpty) {
           // HashMap to add up sizes of all blocks at the same location
           val locs = new HashMap[BlockManagerId, Long]
@@ -502,77 +539,24 @@ private[spark] class MapOutputTrackerMaster(conf: 
SparkConf,
     }
   }
 
-  private def removeBroadcast(bcast: Broadcast[_]): Unit = {
-    if (null != bcast) {
-      broadcastManager.unbroadcast(bcast.id,
-        removeFromDriver = true, blocking = false)
+  /** Called to get current epoch number. */
+  def getEpoch: Long = {
+    epochLock.synchronized {
+      return epoch
     }
   }
 
-  private def clearCachedBroadcast(): Unit = {
-    for (cached <- cachedSerializedBroadcast) removeBroadcast(cached._2)
-    cachedSerializedBroadcast.clear()
-  }
-
-  def getSerializedMapOutputStatuses(shuffleId: Int): Array[Byte] = {
-    var statuses: Array[MapStatus] = null
-    var retBytes: Array[Byte] = null
-    var epochGotten: Long = -1
-
-    // Check to see if we have a cached version, returns true if it does
-    // and has side effect of setting retBytes.  If not returns false
-    // with side effect of setting statuses
-    def checkCachedStatuses(): Boolean = {
-      epochLock.synchronized {
-        if (epoch > cacheEpoch) {
-          cachedSerializedStatuses.clear()
-          clearCachedBroadcast()
-          cacheEpoch = epoch
-        }
-        cachedSerializedStatuses.get(shuffleId) match {
-          case Some(bytes) =>
-            retBytes = bytes
-            true
-          case None =>
-            logDebug("cached status not found for : " + shuffleId)
-            statuses = mapStatuses.getOrElse(shuffleId, Array.empty[MapStatus])
-            epochGotten = epoch
-            false
-        }
-      }
-    }
-
-    if (checkCachedStatuses()) return retBytes
-    var shuffleIdLock = shuffleIdLocks.get(shuffleId)
-    if (null == shuffleIdLock) {
-      val newLock = new Object()
-      // in general, this condition should be false - but good to be paranoid
-      val prevLock = shuffleIdLocks.putIfAbsent(shuffleId, newLock)
-      shuffleIdLock = if (null != prevLock) prevLock else newLock
-    }
-    // synchronize so we only serialize/broadcast it once since multiple 
threads call
-    // in parallel
-    shuffleIdLock.synchronized {
-      // double check to make sure someone else didn't serialize and cache the 
same
-      // mapstatus while we were waiting on the synchronize
-      if (checkCachedStatuses()) return retBytes
-
-      // If we got here, we failed to find the serialized locations in the 
cache, so we pulled
-      // out a snapshot of the locations as "statuses"; let's serialize and 
return that
-      val (bytes, bcast) = MapOutputTracker.serializeMapStatuses(statuses, 
broadcastManager,
-        isLocal, minSizeForBroadcast)
-      logInfo("Size of output statuses for shuffle %d is %d 
bytes".format(shuffleId, bytes.length))
-      // Add them into the table only if the epoch hasn't changed while we 
were working
-      epochLock.synchronized {
-        if (epoch == epochGotten) {
-          cachedSerializedStatuses(shuffleId) = bytes
-          if (null != bcast) cachedSerializedBroadcast(shuffleId) = bcast
-        } else {
-          logInfo("Epoch changed, not caching!")
-          removeBroadcast(bcast)
+  // This method is only called in local-mode.
+  def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, 
endPartition: Int)
+      : Seq[(BlockManagerId, Seq[(BlockId, Long)])] = {
+    logDebug(s"Fetching outputs for shuffle $shuffleId, partitions 
$startPartition-$endPartition")
+    shuffleStatuses.get(shuffleId) match {
+      case Some (shuffleStatus) =>
+        shuffleStatus.withMapStatuses { statuses =>
+          MapOutputTracker.convertMapStatuses(shuffleId, startPartition, 
endPartition, statuses)
         }
-      }
-      bytes
+      case None =>
+        Seq.empty
     }
   }
 
@@ -580,21 +564,121 @@ private[spark] class MapOutputTrackerMaster(conf: 
SparkConf,
     mapOutputRequests.offer(PoisonPill)
     threadpool.shutdown()
     sendTracker(StopMapOutputTracker)
-    mapStatuses.clear()
     trackerEndpoint = null
-    cachedSerializedStatuses.clear()
-    clearCachedBroadcast()
-    shuffleIdLocks.clear()
+    shuffleStatuses.clear()
   }
 }
 
 /**
- * MapOutputTracker for the executors, which fetches map output information 
from the driver's
- * MapOutputTrackerMaster.
+ * Executor-side client for fetching map output info from the driver's 
MapOutputTrackerMaster.
+ * Note that this is not used in local-mode; instead, local-mode Executors 
access the
+ * MapOutputTrackerMaster directly (which is possible because the master and 
worker share a comon
+ * superclass).
  */
 private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends 
MapOutputTracker(conf) {
-  protected val mapStatuses: Map[Int, Array[MapStatus]] =
+
+  val mapStatuses: Map[Int, Array[MapStatus]] =
     new ConcurrentHashMap[Int, Array[MapStatus]]().asScala
+
+  /** Remembers which map output locations are currently being fetched on an 
executor. */
+  private val fetching = new HashSet[Int]
+
+  override def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, 
endPartition: Int)
+      : Seq[(BlockManagerId, Seq[(BlockId, Long)])] = {
+    logDebug(s"Fetching outputs for shuffle $shuffleId, partitions 
$startPartition-$endPartition")
+    val statuses = getStatuses(shuffleId)
+    try {
+      MapOutputTracker.convertMapStatuses(shuffleId, startPartition, 
endPartition, statuses)
+    } catch {
+      case e: MetadataFetchFailedException =>
+        // We experienced a fetch failure so our mapStatuses cache is 
outdated; clear it:
+        mapStatuses.clear()
+        throw e
+    }
+  }
+
+  /**
+   * Get or fetch the array of MapStatuses for a given shuffle ID. NOTE: 
clients MUST synchronize
+   * on this array when reading it, because on the driver, we may be changing 
it in place.
+   *
+   * (It would be nice to remove this restriction in the future.)
+   */
+  private def getStatuses(shuffleId: Int): Array[MapStatus] = {
+    val statuses = mapStatuses.get(shuffleId).orNull
+    if (statuses == null) {
+      logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching 
them")
+      val startTime = System.currentTimeMillis
+      var fetchedStatuses: Array[MapStatus] = null
+      fetching.synchronized {
+        // Someone else is fetching it; wait for them to be done
+        while (fetching.contains(shuffleId)) {
+          try {
+            fetching.wait()
+          } catch {
+            case e: InterruptedException =>
+          }
+        }
+
+        // Either while we waited the fetch happened successfully, or
+        // someone fetched it in between the get and the fetching.synchronized.
+        fetchedStatuses = mapStatuses.get(shuffleId).orNull
+        if (fetchedStatuses == null) {
+          // We have to do the fetch, get others to wait for us.
+          fetching += shuffleId
+        }
+      }
+
+      if (fetchedStatuses == null) {
+        // We won the race to fetch the statuses; do so
+        logInfo("Doing the fetch; tracker endpoint = " + trackerEndpoint)
+        // This try-finally prevents hangs due to timeouts:
+        try {
+          val fetchedBytes = 
askTracker[Array[Byte]](GetMapOutputStatuses(shuffleId))
+          fetchedStatuses = 
MapOutputTracker.deserializeMapStatuses(fetchedBytes)
+          logInfo("Got the output locations")
+          mapStatuses.put(shuffleId, fetchedStatuses)
+        } finally {
+          fetching.synchronized {
+            fetching -= shuffleId
+            fetching.notifyAll()
+          }
+        }
+      }
+      logDebug(s"Fetching map output statuses for shuffle $shuffleId took " +
+        s"${System.currentTimeMillis - startTime} ms")
+
+      if (fetchedStatuses != null) {
+        fetchedStatuses
+      } else {
+        logError("Missing all output locations for shuffle " + shuffleId)
+        throw new MetadataFetchFailedException(
+          shuffleId, -1, "Missing all output locations for shuffle " + 
shuffleId)
+      }
+    } else {
+      statuses
+    }
+  }
+
+
+  /** Unregister shuffle data. */
+  def unregisterShuffle(shuffleId: Int): Unit = {
+    mapStatuses.remove(shuffleId)
+  }
+
+  /**
+   * Called from executors to update the epoch number, potentially clearing 
old outputs
+   * because of a fetch failure. Each executor task calls this with the latest 
epoch
+   * number on the driver at the time it was created.
+   */
+  def updateEpoch(newEpoch: Long): Unit = {
+    epochLock.synchronized {
+      if (newEpoch > epoch) {
+        logInfo("Updating epoch to " + newEpoch + " and clearing cache")
+        epoch = newEpoch
+        mapStatuses.clear()
+      }
+    }
+  }
 }
 
 private[spark] object MapOutputTracker extends Logging {
@@ -683,7 +767,7 @@ private[spark] object MapOutputTracker extends Logging {
    *         and the second item is a sequence of (shuffle block ID, shuffle 
block size) tuples
    *         describing the shuffle blocks that are stored at that block 
manager.
    */
-  private def convertMapStatuses(
+  def convertMapStatuses(
       shuffleId: Int,
       startPartition: Int,
       endPartition: Int,

http://git-wip-us.apache.org/repos/asf/spark/blob/3158fc3c/core/src/main/scala/org/apache/spark/Partitioner.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/Partitioner.scala 
b/core/src/main/scala/org/apache/spark/Partitioner.scala
index f83f527..93a9337 100644
--- a/core/src/main/scala/org/apache/spark/Partitioner.scala
+++ b/core/src/main/scala/org/apache/spark/Partitioner.scala
@@ -32,6 +32,9 @@ import org.apache.spark.util.random.SamplingUtils
 /**
  * An object that defines how the elements in a key-value pair RDD are 
partitioned by key.
  * Maps each key to a partition ID, from 0 to `numPartitions - 1`.
+ *
+ * Note that, partitioner must be deterministic, i.e. it must return the same 
partition id given
+ * the same partition key.
  */
 abstract class Partitioner extends Serializable {
   def numPartitions: Int

http://git-wip-us.apache.org/repos/asf/spark/blob/3158fc3c/core/src/main/scala/org/apache/spark/executor/Executor.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala 
b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index 36b1743..47c51c0 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -325,8 +325,14 @@ private[spark] class Executor(
           throw new TaskKilledException(killReason.get)
         }
 
-        logDebug("Task " + taskId + "'s epoch is " + task.epoch)
-        env.mapOutputTracker.updateEpoch(task.epoch)
+        // The purpose of updating the epoch here is to invalidate executor 
map output status cache
+        // in case FetchFailures have occurred. In local mode 
`env.mapOutputTracker` will be
+        // MapOutputTrackerMaster and its cache invalidation is not based on 
epoch numbers so
+        // we don't need to make any special calls here.
+        if (!isLocal) {
+          logDebug("Task " + taskId + "'s epoch is " + task.epoch)
+          
env.mapOutputTracker.asInstanceOf[MapOutputTrackerWorker].updateEpoch(task.epoch)
+        }
 
         // Run the actual task and measure its runtime.
         taskStart = System.currentTimeMillis()

http://git-wip-us.apache.org/repos/asf/spark/blob/3158fc3c/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala 
b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala
index e4587c9..15128f0 100644
--- a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala
@@ -23,11 +23,22 @@ import org.apache.spark.{Partition, TaskContext}
 
 /**
  * An RDD that applies the provided function to every partition of the parent 
RDD.
+ *
+ * @param prev the parent RDD.
+ * @param f The function used to map a tuple of (TaskContext, partition index, 
input iterator) to
+ *          an output iterator.
+ * @param preservesPartitioning Whether the input function preserves the 
partitioner, which should
+ *                              be `false` unless `prev` is a pair RDD and the 
input function
+ *                              doesn't modify the keys.
+ * @param isOrderSensitive whether or not the function is order-sensitive. If 
it's order
+ *                         sensitive, it may return totally different result 
when the input order
+ *                         is changed. Mostly stateful functions are 
order-sensitive.
  */
 private[spark] class MapPartitionsRDD[U: ClassTag, T: ClassTag](
     var prev: RDD[T],
     f: (TaskContext, Int, Iterator[T]) => Iterator[U],  // (TaskContext, 
partition index, iterator)
-    preservesPartitioning: Boolean = false)
+    preservesPartitioning: Boolean = false,
+    isOrderSensitive: Boolean = false)
   extends RDD[U](prev) {
 
   override val partitioner = if (preservesPartitioning) 
firstParent[T].partitioner else None
@@ -41,4 +52,12 @@ private[spark] class MapPartitionsRDD[U: ClassTag, T: 
ClassTag](
     super.clearDependencies()
     prev = null
   }
+
+  override protected def getOutputDeterministicLevel = {
+    if (isOrderSensitive && prev.outputDeterministicLevel == 
DeterministicLevel.UNORDERED) {
+      DeterministicLevel.INDETERMINATE
+    } else {
+      super.getOutputDeterministicLevel
+    }
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/3158fc3c/core/src/main/scala/org/apache/spark/rdd/RDD.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala 
b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 102836d..4ff0f83 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -461,8 +461,9 @@ abstract class RDD[T: ClassTag](
 
       // include a shuffle step so that our upstream tasks are still 
distributed
       new CoalescedRDD(
-        new ShuffledRDD[Int, T, T](mapPartitionsWithIndex(distributePartition),
-        new HashPartitioner(numPartitions)),
+        new ShuffledRDD[Int, T, T](
+          mapPartitionsWithIndexInternal(distributePartition, isOrderSensitive 
= true),
+          new HashPartitioner(numPartitions)),
         numPartitions,
         partitionCoalescer).values
     } else {
@@ -806,16 +807,21 @@ abstract class RDD[T: ClassTag](
    * serializable and don't require closure cleaning.
    *
    * @param preservesPartitioning indicates whether the input function 
preserves the partitioner,
-   * which should be `false` unless this is a pair RDD and the input function 
doesn't modify
-   * the keys.
+   *                              which should be `false` unless this is a 
pair RDD and the input
+   *                              function doesn't modify the keys.
+   * @param isOrderSensitive whether or not the function is order-sensitive. 
If it's order
+   *                         sensitive, it may return totally different result 
when the input order
+   *                         is changed. Mostly stateful functions are 
order-sensitive.
    */
   private[spark] def mapPartitionsWithIndexInternal[U: ClassTag](
       f: (Int, Iterator[T]) => Iterator[U],
-      preservesPartitioning: Boolean = false): RDD[U] = withScope {
+      preservesPartitioning: Boolean = false,
+      isOrderSensitive: Boolean = false): RDD[U] = withScope {
     new MapPartitionsRDD(
       this,
       (context: TaskContext, index: Int, iter: Iterator[T]) => f(index, iter),
-      preservesPartitioning)
+      preservesPartitioning = preservesPartitioning,
+      isOrderSensitive = isOrderSensitive)
   }
 
   /**
@@ -1635,6 +1641,16 @@ abstract class RDD[T: ClassTag](
   }
 
   /**
+   * Return whether this RDD is reliably checkpointed and materialized.
+   */
+  private[rdd] def isReliablyCheckpointed: Boolean = {
+    checkpointData match {
+      case Some(reliable: ReliableRDDCheckpointData[_]) if 
reliable.isCheckpointed => true
+      case _ => false
+    }
+  }
+
+  /**
    * Gets the name of the directory to which this RDD was checkpointed.
    * This is not defined if the RDD is checkpointed locally.
    */
@@ -1838,6 +1854,63 @@ abstract class RDD[T: ClassTag](
   def toJavaRDD() : JavaRDD[T] = {
     new JavaRDD(this)(elementClassTag)
   }
+
+  /**
+   * Returns the deterministic level of this RDD's output. Please refer to 
[[DeterministicLevel]]
+   * for the definition.
+   *
+   * By default, an reliably checkpointed RDD, or RDD without parents(root 
RDD) is DETERMINATE. For
+   * RDDs with parents, we will generate a deterministic level candidate per 
parent according to
+   * the dependency. The deterministic level of the current RDD is the 
deterministic level
+   * candidate that is deterministic least. Please override 
[[getOutputDeterministicLevel]] to
+   * provide custom logic of calculating output deterministic level.
+   */
+  // TODO: make it public so users can set deterministic level to their custom 
RDDs.
+  // TODO: this can be per-partition. e.g. UnionRDD can have different 
deterministic level for
+  // different partitions.
+  private[spark] final lazy val outputDeterministicLevel: 
DeterministicLevel.Value = {
+    if (isReliablyCheckpointed) {
+      DeterministicLevel.DETERMINATE
+    } else {
+      getOutputDeterministicLevel
+    }
+  }
+
+  @DeveloperApi
+  protected def getOutputDeterministicLevel: DeterministicLevel.Value = {
+    val deterministicLevelCandidates = dependencies.map {
+      // The shuffle is not really happening, treat it like narrow dependency 
and assume the output
+      // deterministic level of current RDD is same as parent.
+      case dep: ShuffleDependency[_, _, _] if dep.rdd.partitioner.exists(_ == 
dep.partitioner) =>
+        dep.rdd.outputDeterministicLevel
+
+      case dep: ShuffleDependency[_, _, _] =>
+        if (dep.rdd.outputDeterministicLevel == 
DeterministicLevel.INDETERMINATE) {
+          // If map output was indeterminate, shuffle output will be 
indeterminate as well
+          DeterministicLevel.INDETERMINATE
+        } else if (dep.keyOrdering.isDefined && dep.aggregator.isDefined) {
+          // if aggregator specified (and so unique keys) and key ordering 
specified - then
+          // consistent ordering.
+          DeterministicLevel.DETERMINATE
+        } else {
+          // In Spark, the reducer fetches multiple remote shuffle blocks at 
the same time, and
+          // the arrival order of these shuffle blocks are totally random. 
Even if the parent map
+          // RDD is DETERMINATE, the reduce RDD is always UNORDERED.
+          DeterministicLevel.UNORDERED
+        }
+
+      // For narrow dependency, assume the output deterministic level of 
current RDD is same as
+      // parent.
+      case dep => dep.rdd.outputDeterministicLevel
+    }
+
+    if (deterministicLevelCandidates.isEmpty) {
+      // By default we assume the root RDD is determinate.
+      DeterministicLevel.DETERMINATE
+    } else {
+      deterministicLevelCandidates.maxBy(_.id)
+    }
+  }
 }
 
 
@@ -1891,3 +1964,18 @@ object RDD {
     new DoubleRDDFunctions(rdd.map(x => num.toDouble(x)))
   }
 }
+
+/**
+ * The deterministic level of RDD's output (i.e. what `RDD#compute` returns). 
This explains how
+ * the output will diff when Spark reruns the tasks for the RDD. There are 3 
deterministic levels:
+ * 1. DETERMINATE: The RDD output is always the same data set in the same 
order after a rerun.
+ * 2. UNORDERED: The RDD output is always the same data set but the order can 
be different
+ *               after a rerun.
+ * 3. INDETERMINATE. The RDD output can be different after a rerun.
+ *
+ * Note that, the output of an RDD usually relies on the parent RDDs. When the 
parent RDD's output
+ * is INDETERMINATE, it's very likely the RDD's output is also INDETERMINATE.
+ */
+private[spark] object DeterministicLevel extends Enumeration {
+  val DETERMINATE, UNORDERED, INDETERMINATE = Value
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/3158fc3c/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala 
b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 099bc2e..cb6cdcd 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -38,7 +38,7 @@ import org.apache.spark.executor.TaskMetrics
 import org.apache.spark.internal.Logging
 import org.apache.spark.network.util.JavaUtils
 import org.apache.spark.partial.{ApproximateActionListener, 
ApproximateEvaluator, PartialResult}
-import org.apache.spark.rdd.{RDD, RDDCheckpointData}
+import org.apache.spark.rdd.{DeterministicLevel, RDD, RDDCheckpointData}
 import org.apache.spark.rpc.RpcTimeout
 import org.apache.spark.storage._
 import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat
@@ -328,25 +328,14 @@ class DAGScheduler(
     val numTasks = rdd.partitions.length
     val parents = getOrCreateParentStages(rdd, jobId)
     val id = nextStageId.getAndIncrement()
-    val stage = new ShuffleMapStage(id, rdd, numTasks, parents, jobId, 
rdd.creationSite, shuffleDep)
+    val stage = new ShuffleMapStage(
+      id, rdd, numTasks, parents, jobId, rdd.creationSite, shuffleDep, 
mapOutputTracker)
 
     stageIdToStage(id) = stage
     shuffleIdToMapStage(shuffleDep.shuffleId) = stage
     updateJobIdStageIdMaps(jobId, stage)
 
-    if (mapOutputTracker.containsShuffle(shuffleDep.shuffleId)) {
-      // A previously run stage generated partitions for this shuffle, so for 
each output
-      // that's still available, copy information about that output location 
to the new stage
-      // (so we don't unnecessarily re-compute that data).
-      val serLocs = 
mapOutputTracker.getSerializedMapOutputStatuses(shuffleDep.shuffleId)
-      val locs = MapOutputTracker.deserializeMapStatuses(serLocs)
-      (0 until locs.length).foreach { i =>
-        if (locs(i) ne null) {
-          // locs(i) will be null if missing
-          stage.addOutputLoc(i, locs(i))
-        }
-      }
-    } else {
+    if (!mapOutputTracker.containsShuffle(shuffleDep.shuffleId)) {
       // Kind of ugly: need to register RDDs with the cache and map output 
tracker here
       // since we can't do it in the RDD constructor because # of partitions 
is unknown
       logInfo("Registering RDD " + rdd.id + " (" + rdd.getCreationSite + ")")
@@ -1240,7 +1229,8 @@ class DAGScheduler(
               // The epoch of the task is acceptable (i.e., the task was 
launched after the most
               // recent failure we're aware of for the executor), so mark the 
task's output as
               // available.
-              shuffleStage.addOutputLoc(smt.partitionId, status)
+              mapOutputTracker.registerMapOutput(
+                shuffleStage.shuffleDep.shuffleId, smt.partitionId, status)
               // Remove the task's partition from pending partitions. This may 
have already been
               // done above, but will not have been done yet in cases where 
the task attempt was
               // from an earlier attempt of the stage (i.e., not the attempt 
that's currently
@@ -1257,16 +1247,14 @@ class DAGScheduler(
               logInfo("waiting: " + waitingStages)
               logInfo("failed: " + failedStages)
 
-              // We supply true to increment the epoch number here in case 
this is a
-              // recomputation of the map outputs. In that case, some nodes 
may have cached
-              // locations with holes (from when we detected the error) and 
will need the
-              // epoch incremented to refetch them.
-              // TODO: Only increment the epoch number if this is not the 
first time
-              //       we registered these map outputs.
-              mapOutputTracker.registerMapOutputs(
-                shuffleStage.shuffleDep.shuffleId,
-                shuffleStage.outputLocInMapOutputTrackerFormat(),
-                changeEpoch = true)
+              // This call to increment the epoch may not be strictly 
necessary, but it is retained
+              // for now in order to minimize the changes in behavior from an 
earlier version of the
+              // code. This existing behavior of always incrementing the epoch 
following any
+              // successful shuffle map stage completion may have benefits by 
causing unneeded
+              // cached map outputs to be cleaned up earlier on executors. In 
the future we can
+              // consider removing this call, but this will require some extra 
investigation.
+              // See 
https://github.com/apache/spark/pull/17955/files#r117385673 for more details.
+              mapOutputTracker.incrementEpoch()
 
               clearCacheLocs()
 
@@ -1344,6 +1332,63 @@ class DAGScheduler(
             failedStages += failedStage
             failedStages += mapStage
             if (noResubmitEnqueued) {
+              // If the map stage is INDETERMINATE, which means the map tasks 
may return
+              // different result when re-try, we need to re-try all the tasks 
of the failed
+              // stage and its succeeding stages, because the input data will 
be changed after the
+              // map tasks are re-tried.
+              // Note that, if map stage is UNORDERED, we are fine. The 
shuffle partitioner is
+              // guaranteed to be determinate, so the input data of the 
reducers will not change
+              // even if the map tasks are re-tried.
+              if (mapStage.rdd.outputDeterministicLevel == 
DeterministicLevel.INDETERMINATE) {
+                // It's a little tricky to find all the succeeding stages of 
`failedStage`, because
+                // each stage only know its parents not children. Here we 
traverse the stages from
+                // the leaf nodes (the result stages of active jobs), and 
rollback all the stages
+                // in the stage chains that connect to the `failedStage`. To 
speed up the stage
+                // traversing, we collect the stages to rollback first. If a 
stage needs to
+                // rollback, all its succeeding stages need to rollback to.
+                val stagesToRollback = 
scala.collection.mutable.HashSet(failedStage)
+
+                def collectStagesToRollback(stageChain: List[Stage]): Unit = {
+                  if (stagesToRollback.contains(stageChain.head)) {
+                    stageChain.drop(1).foreach(s => stagesToRollback += s)
+                  } else {
+                    stageChain.head.parents.foreach { s =>
+                      collectStagesToRollback(s :: stageChain)
+                    }
+                  }
+                }
+
+                def generateErrorMessage(stage: Stage): String = {
+                  "A shuffle map stage with indeterminate output was failed 
and retried. " +
+                    s"However, Spark cannot rollback the $stage to re-process 
the input data, " +
+                    "and has to fail this job. Please eliminate the 
indeterminacy by " +
+                    "checkpointing the RDD before repartition and try again."
+                }
+
+                activeJobs.foreach(job => 
collectStagesToRollback(job.finalStage :: Nil))
+
+                stagesToRollback.foreach {
+                  case mapStage: ShuffleMapStage =>
+                    val numMissingPartitions = 
mapStage.findMissingPartitions().length
+                    if (numMissingPartitions < mapStage.numTasks) {
+                      // TODO: support to rollback shuffle files.
+                      // Currently the shuffle writing is "first write wins", 
so we can't re-run a
+                      // shuffle map stage and overwrite existing shuffle 
files. We have to finish
+                      // SPARK-8029 first.
+                      abortStage(mapStage, generateErrorMessage(mapStage), 
None)
+                    }
+
+                  case resultStage: ResultStage if 
resultStage.activeJob.isDefined =>
+                    val numMissingPartitions = 
resultStage.findMissingPartitions().length
+                    if (numMissingPartitions < resultStage.numTasks) {
+                      // TODO: support to rollback result tasks.
+                      abortStage(resultStage, 
generateErrorMessage(resultStage), None)
+                    }
+
+                  case _ =>
+                }
+              }
+
               // We expect one executor failure to trigger many FetchFailures 
in rapid succession,
               // but all of those task failures can typically be handled by a 
single resubmission of
               // the failed stage.  We avoid flooding the scheduler's event 
queue with resubmit
@@ -1367,7 +1412,6 @@ class DAGScheduler(
           }
           // Mark the map whose fetch failed as broken in the map stage
           if (mapId != -1) {
-            mapStage.removeOutputLoc(mapId, bmAddress)
             mapOutputTracker.unregisterMapOutput(shuffleId, mapId, bmAddress)
           }
 
@@ -1416,17 +1460,7 @@ class DAGScheduler(
 
       if (filesLost || !env.blockManager.externalShuffleServiceEnabled) {
         logInfo("Shuffle files lost for executor: %s (epoch 
%d)".format(execId, currentEpoch))
-        // TODO: This will be really slow if we keep accumulating shuffle map 
stages
-        for ((shuffleId, stage) <- shuffleIdToMapStage) {
-          stage.removeOutputsOnExecutor(execId)
-          mapOutputTracker.registerMapOutputs(
-            shuffleId,
-            stage.outputLocInMapOutputTrackerFormat(),
-            changeEpoch = true)
-        }
-        if (shuffleIdToMapStage.isEmpty) {
-          mapOutputTracker.incrementEpoch()
-        }
+        mapOutputTracker.removeOutputsOnExecutor(execId)
         clearCacheLocs()
       }
     } else {
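
To see the shape of the rollback traversal on its own, here is a minimal, self-contained sketch. `RollbackTraversalSketch` and `SimpleStage` are hypothetical stand-ins for this illustration only, not the scheduler's real classes:

    object RollbackTraversalSketch {
      // Hypothetical stand-in for the scheduler's Stage class; only parent links matter here.
      case class SimpleStage(id: Int, parents: List[SimpleStage])

      // Walk from each job's final (leaf) stage toward its parents. Once a chain reaches a stage
      // that is already marked for rollback, every stage after it on the chain is marked as well.
      def stagesToRollback(failed: SimpleStage, finalStages: Seq[SimpleStage]): Set[Int] = {
        val toRollback = scala.collection.mutable.HashSet(failed)

        def collect(chain: List[SimpleStage]): Unit = {
          if (toRollback.contains(chain.head)) {
            chain.drop(1).foreach(toRollback += _)
          } else {
            chain.head.parents.foreach(p => collect(p :: chain))
          }
        }

        finalStages.foreach(s => collect(s :: Nil))
        toRollback.map(_.id).toSet
      }

      def main(args: Array[String]): Unit = {
        val failedMapStage = SimpleStage(0, Nil)                 // the failed, indeterminate map stage
        val intermediate   = SimpleStage(1, List(failedMapStage))
        val resultStage    = SimpleStage(2, List(intermediate))  // result stage of the active job
        println(stagesToRollback(failedMapStage, Seq(resultStage)))  // contains 0, 1 and 2
      }
    }

The key property is that everything on a chain from the failed, indeterminate map stage down to an active job's final stage gets marked, so no stage keeps partial results computed from the old, now-changed shuffle output.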

http://git-wip-us.apache.org/repos/asf/spark/blob/3158fc3c/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala
----------------------------------------------------------------------
diff --git 
a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala 
b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala
index db4d9ef..05f650f 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala
@@ -19,9 +19,8 @@ package org.apache.spark.scheduler
 
 import scala.collection.mutable.HashSet
 
-import org.apache.spark.ShuffleDependency
+import org.apache.spark.{MapOutputTrackerMaster, ShuffleDependency, SparkEnv}
 import org.apache.spark.rdd.RDD
-import org.apache.spark.storage.BlockManagerId
 import org.apache.spark.util.CallSite
 
 /**
@@ -42,13 +41,12 @@ private[spark] class ShuffleMapStage(
     parents: List[Stage],
     firstJobId: Int,
     callSite: CallSite,
-    val shuffleDep: ShuffleDependency[_, _, _])
+    val shuffleDep: ShuffleDependency[_, _, _],
+    mapOutputTrackerMaster: MapOutputTrackerMaster)
   extends Stage(id, rdd, numTasks, parents, firstJobId, callSite) {
 
   private[this] var _mapStageJobs: List[ActiveJob] = Nil
 
-  private[this] var _numAvailableOutputs: Int = 0
-
   /**
    * Partitions that either haven't yet been computed, or that were computed 
on an executor
    * that has since been lost, so should be re-computed.  This variable is 
used by the
@@ -60,13 +58,6 @@ private[spark] class ShuffleMapStage(
    */
   val pendingPartitions = new HashSet[Int]
 
-  /**
-   * List of [[MapStatus]] for each partition. The index of the array is the 
map partition id,
-   * and each value in the array is the list of possible [[MapStatus]] for a 
partition
-   * (a single task might run multiple times).
-   */
-  private[this] val outputLocs = 
Array.fill[List[MapStatus]](numPartitions)(Nil)
-
   override def toString: String = "ShuffleMapStage " + id
 
   /**
@@ -88,69 +79,18 @@ private[spark] class ShuffleMapStage(
   /**
    * Number of partitions that have shuffle outputs.
    * When this reaches [[numPartitions]], this map stage is ready.
-   * This should be kept consistent as `outputLocs.filter(!_.isEmpty).size`.
    */
-  def numAvailableOutputs: Int = _numAvailableOutputs
+  def numAvailableOutputs: Int =
+    mapOutputTrackerMaster.getNumAvailableOutputs(shuffleDep.shuffleId)
 
   /**
    * Returns true if the map stage is ready, i.e. all partitions have shuffle 
outputs.
-   * This should be the same as `outputLocs.contains(Nil)`.
    */
-  def isAvailable: Boolean = _numAvailableOutputs == numPartitions
+  def isAvailable: Boolean = numAvailableOutputs == numPartitions
 
   /** Returns the sequence of partition ids that are missing (i.e. needs to be 
computed). */
   override def findMissingPartitions(): Seq[Int] = {
-    val missing = (0 until numPartitions).filter(id => outputLocs(id).isEmpty)
-    assert(missing.size == numPartitions - _numAvailableOutputs,
-      s"${missing.size} missing, expected ${numPartitions - 
_numAvailableOutputs}")
-    missing
-  }
-
-  def addOutputLoc(partition: Int, status: MapStatus): Unit = {
-    val prevList = outputLocs(partition)
-    outputLocs(partition) = status :: prevList
-    if (prevList == Nil) {
-      _numAvailableOutputs += 1
-    }
-  }
-
-  def removeOutputLoc(partition: Int, bmAddress: BlockManagerId): Unit = {
-    val prevList = outputLocs(partition)
-    val newList = prevList.filterNot(_.location == bmAddress)
-    outputLocs(partition) = newList
-    if (prevList != Nil && newList == Nil) {
-      _numAvailableOutputs -= 1
-    }
-  }
-
-  /**
-   * Returns an array of [[MapStatus]] (index by partition id). For each 
partition, the returned
-   * value contains only one (i.e. the first) [[MapStatus]]. If there is no 
entry for the partition,
-   * that position is filled with null.
-   */
-  def outputLocInMapOutputTrackerFormat(): Array[MapStatus] = {
-    outputLocs.map(_.headOption.orNull)
-  }
-
-  /**
-   * Removes all shuffle outputs associated with this executor. Note that this 
will also remove
-   * outputs which are served by an external shuffle server (if one exists), 
as they are still
-   * registered with this execId.
-   */
-  def removeOutputsOnExecutor(execId: String): Unit = {
-    var becameUnavailable = false
-    for (partition <- 0 until numPartitions) {
-      val prevList = outputLocs(partition)
-      val newList = prevList.filterNot(_.location.executorId == execId)
-      outputLocs(partition) = newList
-      if (prevList != Nil && newList == Nil) {
-        becameUnavailable = true
-        _numAvailableOutputs -= 1
-      }
-    }
-    if (becameUnavailable) {
-      logInfo("%s is now unavailable on executor %s (%d/%d, %s)".format(
-        this, execId, _numAvailableOutputs, numPartitions, isAvailable))
-    }
+    mapOutputTrackerMaster
+      .findMissingPartitions(shuffleDep.shuffleId)
+      .getOrElse(0 until numPartitions)
   }
 }
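
The net effect of this file's change is that a shuffle map stage no longer keeps its own `outputLocs` bookkeeping; availability and missing partitions are derived from the map output tracker, so there is a single source of truth when outputs are unregistered after a fetch failure. A rough sketch of that delegation shape, using hypothetical `OutputTracker`/`MapStageView` names rather than Spark's actual classes:

    object TrackerDelegationSketch {
      // Hypothetical interfaces for this sketch only.
      trait OutputTracker {
        def getNumAvailableOutputs(shuffleId: Int): Int
        def findMissingPartitions(shuffleId: Int): Option[Seq[Int]]
      }

      // The stage stores no output locations of its own; everything is derived from the tracker.
      class MapStageView(shuffleId: Int, numPartitions: Int, tracker: OutputTracker) {
        def numAvailableOutputs: Int = tracker.getNumAvailableOutputs(shuffleId)
        def isAvailable: Boolean = numAvailableOutputs == numPartitions
        def findMissingPartitions(): Seq[Int] =
          tracker.findMissingPartitions(shuffleId).getOrElse(0 until numPartitions)
      }

      def main(args: Array[String]): Unit = {
        // A stub tracker that knows about one of two partitions.
        val stub = new OutputTracker {
          def getNumAvailableOutputs(shuffleId: Int): Int = 1
          def findMissingPartitions(shuffleId: Int): Option[Seq[Int]] = Some(Seq(1))
        }
        val stage = new MapStageView(shuffleId = 0, numPartitions = 2, tracker = stub)
        println((stage.isAvailable, stage.findMissingPartitions()))  // (false, List(1))
      }
    }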

http://git-wip-us.apache.org/repos/asf/spark/blob/3158fc3c/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
----------------------------------------------------------------------
diff --git 
a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala 
b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
index f8c62b4..bc0d470 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
@@ -129,7 +129,7 @@ private[spark] class TaskSchedulerImpl private[scheduler](
 
   var backend: SchedulerBackend = null
 
-  val mapOutputTracker = SparkEnv.get.mapOutputTracker
+  val mapOutputTracker = 
SparkEnv.get.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster]
 
   private var schedulableBuilder: SchedulableBuilder = null
   // default scheduler is FIFO

http://git-wip-us.apache.org/repos/asf/spark/blob/3158fc3c/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala 
b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
index ca94fd1..82b6fd1 100644
--- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
@@ -138,21 +138,21 @@ class MapOutputTrackerSuite extends SparkFunSuite {
       slaveRpcEnv.setupEndpointRef(rpcEnv.address, 
MapOutputTracker.ENDPOINT_NAME)
 
     masterTracker.registerShuffle(10, 1)
-    masterTracker.incrementEpoch()
     slaveTracker.updateEpoch(masterTracker.getEpoch)
+    // This is expected to fail because no outputs have been registered for 
the shuffle.
     intercept[FetchFailedException] { slaveTracker.getMapSizesByExecutorId(10, 
0) }
 
     val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L))
     masterTracker.registerMapOutput(10, 0, MapStatus(
       BlockManagerId("a", "hostA", 1000), Array(1000L)))
-    masterTracker.incrementEpoch()
     slaveTracker.updateEpoch(masterTracker.getEpoch)
     assert(slaveTracker.getMapSizesByExecutorId(10, 0) ===
       Seq((BlockManagerId("a", "hostA", 1000), ArrayBuffer((ShuffleBlockId(10, 
0, 0), size1000)))))
     assert(0 == masterTracker.getNumCachedSerializedBroadcast)
 
+    val masterTrackerEpochBeforeLossOfMapOutput = masterTracker.getEpoch
     masterTracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 
1000))
-    masterTracker.incrementEpoch()
+    assert(masterTracker.getEpoch > masterTrackerEpochBeforeLossOfMapOutput)
     slaveTracker.updateEpoch(masterTracker.getEpoch)
     intercept[FetchFailedException] { slaveTracker.getMapSizesByExecutorId(10, 
0) }
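
The new assertion shows that unregistering a map output now bumps the master tracker's epoch on its own, which is why the manual `incrementEpoch()` calls could be dropped from the test. A small stand-alone sketch of that epoch-based invalidation pattern (the `Tracker` class below is illustrative, not Spark's `MapOutputTrackerMaster`):

    object EpochSketch {
      class Tracker {
        private var epoch = 0L
        private val outputs = scala.collection.mutable.Map.empty[(Int, Int), String]

        def getEpoch: Long = synchronized(epoch)

        def registerMapOutput(shuffleId: Int, mapId: Int, location: String): Unit =
          synchronized { outputs((shuffleId, mapId)) = location }

        def unregisterMapOutput(shuffleId: Int, mapId: Int): Unit = synchronized {
          if (outputs.remove((shuffleId, mapId)).isDefined) {
            epoch += 1  // losing an output invalidates cached statuses; no manual bump needed
          }
        }
      }

      def main(args: Array[String]): Unit = {
        val t = new Tracker
        t.registerMapOutput(10, 0, "hostA")
        val before = t.getEpoch
        t.unregisterMapOutput(10, 0)
        assert(t.getEpoch > before)  // mirrors the new assertion in the test above
      }
    }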
 

http://git-wip-us.apache.org/repos/asf/spark/blob/3158fc3c/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala 
b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
index 3b564df..62c40d1 100644
--- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
@@ -333,6 +333,7 @@ abstract class ShuffleSuite extends SparkFunSuite with 
Matchers with LocalSparkC
     val shuffleMapRdd = new MyRDD(sc, 1, Nil)
     val shuffleDep = new ShuffleDependency(shuffleMapRdd, new 
HashPartitioner(1))
     val shuffleHandle = manager.registerShuffle(0, 1, shuffleDep)
+    mapTrackerMaster.registerShuffle(0, 1)
 
     // first attempt -- it's successful
     val writer1 = manager.getWriter[Int, Int](shuffleHandle, 0,
@@ -367,7 +368,7 @@ abstract class ShuffleSuite extends SparkFunSuite with 
Matchers with LocalSparkC
 
     // register one of the map outputs -- doesn't matter which one
     mapOutput1.foreach { case mapStatus =>
-      mapTrackerMaster.registerMapOutputs(0, Array(mapStatus))
+      mapTrackerMaster.registerMapOutput(0, 0, mapStatus)
     }
 
     val reader = manager.getReader[Int, Int](shuffleHandle, 0, 1,

http://git-wip-us.apache.org/repos/asf/spark/blob/3158fc3c/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala
----------------------------------------------------------------------
diff --git 
a/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala 
b/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala
index 2b18ebe..571c6bb 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala
@@ -86,7 +86,8 @@ class BlacklistTrackerSuite extends SparkFunSuite with 
BeforeAndAfterEach with M
     sc = new SparkContext(conf)
     val scheduler = mock[TaskSchedulerImpl]
     when(scheduler.sc).thenReturn(sc)
-    when(scheduler.mapOutputTracker).thenReturn(SparkEnv.get.mapOutputTracker)
+    when(scheduler.mapOutputTracker).thenReturn(
+      SparkEnv.get.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster])
     scheduler
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/3158fc3c/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
----------------------------------------------------------------------
diff --git 
a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala 
b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
index 9112065..1fff0d0 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -30,7 +30,7 @@ import org.scalatest.time.SpanSugar._
 
 import org.apache.spark._
 import org.apache.spark.broadcast.BroadcastManager
-import org.apache.spark.rdd.RDD
+import org.apache.spark.rdd.{DeterministicLevel, RDD}
 import org.apache.spark.scheduler.SchedulingMode.SchedulingMode
 import org.apache.spark.shuffle.{FetchFailedException, 
MetadataFetchFailedException}
 import org.apache.spark.storage.{BlockId, BlockManagerId, BlockManagerMaster}
@@ -56,6 +56,20 @@ class DAGSchedulerEventProcessLoopTester(dagScheduler: 
DAGScheduler)
 
 }
 
+class MyCheckpointRDD(
+    sc: SparkContext,
+    numPartitions: Int,
+    dependencies: List[Dependency[_]],
+    locations: Seq[Seq[String]] = Nil,
+    @(transient @param) tracker: MapOutputTrackerMaster = null,
+    indeterminate: Boolean = false)
+  extends MyRDD(sc, numPartitions, dependencies, locations, tracker, 
indeterminate) {
+
+  // Allow doCheckpoint() on this RDD.
+  override def compute(split: Partition, context: TaskContext): Iterator[(Int, 
Int)] =
+    Iterator.empty
+}
+
 /**
  * An RDD for passing to DAGScheduler. These RDDs will use the dependencies and
  * preferredLocations (if any) that are passed to them. They are deliberately 
not executable
@@ -70,7 +84,8 @@ class MyRDD(
     numPartitions: Int,
     dependencies: List[Dependency[_]],
     locations: Seq[Seq[String]] = Nil,
-    @(transient @param) tracker: MapOutputTrackerMaster = null)
+    @(transient @param) tracker: MapOutputTrackerMaster = null,
+    indeterminate: Boolean = false)
   extends RDD[(Int, Int)](sc, dependencies) with Serializable {
 
   override def compute(split: Partition, context: TaskContext): Iterator[(Int, 
Int)] =
@@ -80,6 +95,10 @@ class MyRDD(
     override def index: Int = i
   }).toArray
 
+  override protected def getOutputDeterministicLevel = {
+    if (indeterminate) DeterministicLevel.INDETERMINATE else super.getOutputDeterministicLevel
+  }
+
   override def getPreferredLocations(partition: Partition): Seq[String] = {
     if (locations.isDefinedAt(partition.index)) {
       locations(partition.index)
@@ -2307,6 +2326,152 @@ class DAGSchedulerSuite extends SparkFunSuite with 
LocalSparkContext with Timeou
     }
   }
 
+  test("SPARK-23207: retry all the succeeding stages when the map stage is 
indeterminate") {
+    val shuffleMapRdd1 = new MyRDD(sc, 2, Nil, indeterminate = true)
+
+    val shuffleDep1 = new ShuffleDependency(shuffleMapRdd1, new 
HashPartitioner(2))
+    val shuffleId1 = shuffleDep1.shuffleId
+    val shuffleMapRdd2 = new MyRDD(sc, 2, List(shuffleDep1), tracker = 
mapOutputTracker)
+
+    val shuffleDep2 = new ShuffleDependency(shuffleMapRdd2, new 
HashPartitioner(2))
+    val shuffleId2 = shuffleDep2.shuffleId
+    val finalRdd = new MyRDD(sc, 2, List(shuffleDep2), tracker = 
mapOutputTracker)
+
+    submit(finalRdd, Array(0, 1))
+
+    // Finish the first shuffle map stage.
+    complete(taskSets(0), Seq(
+      (Success, makeMapStatus("hostA", 2)),
+      (Success, makeMapStatus("hostB", 2))))
+    assert(mapOutputTracker.findMissingPartitions(shuffleId1) === 
Some(Seq.empty))
+
+    // Finish the second shuffle map stage.
+    complete(taskSets(1), Seq(
+      (Success, makeMapStatus("hostC", 2)),
+      (Success, makeMapStatus("hostD", 2))))
+    assert(mapOutputTracker.findMissingPartitions(shuffleId2) === 
Some(Seq.empty))
+
+    // The first task of the final stage failed with fetch failure
+    runEvent(makeCompletionEvent(
+      taskSets(2).tasks(0),
+      FetchFailed(makeBlockManagerId("hostC"), shuffleId2, 0, 0, "ignored"),
+      null))
+
+    val failedStages = scheduler.failedStages.toSeq
+    assert(failedStages.length == 2)
+    // Shuffle blocks of "hostC" are lost, so the first task of `shuffleMapRdd2` needs to retry.
+    assert(failedStages.collect {
+      case stage: ShuffleMapStage if stage.shuffleDep.shuffleId == shuffleId2 
=> stage
+    }.head.findMissingPartitions() == Seq(0))
+    // The result stage is still waiting for its 2 tasks to complete
+    assert(failedStages.collect {
+      case stage: ResultStage => stage
+    }.head.findMissingPartitions() == Seq(0, 1))
+
+    scheduler.resubmitFailedStages()
+
+    // The first task of the `shuffleMapRdd2` failed with fetch failure
+    runEvent(makeCompletionEvent(
+      taskSets(3).tasks(0),
+      FetchFailed(makeBlockManagerId("hostA"), shuffleId1, 0, 0, "ignored"),
+      null))
+
+    // The job should fail because Spark can't rollback the shuffle map stage.
+    assert(failure != null && failure.getMessage.contains("Spark cannot rollback"))
+  }
+
+  private def assertResultStageFailToRollback(mapRdd: MyRDD): Unit = {
+    val shuffleDep = new ShuffleDependency(mapRdd, new HashPartitioner(2))
+    val shuffleId = shuffleDep.shuffleId
+    val finalRdd = new MyRDD(sc, 2, List(shuffleDep), tracker = 
mapOutputTracker)
+
+    submit(finalRdd, Array(0, 1))
+
+    completeShuffleMapStageSuccessfully(taskSets.length - 1, 0, 
numShufflePartitions = 2)
+    assert(mapOutputTracker.findMissingPartitions(shuffleId) === 
Some(Seq.empty))
+
+    // Finish the first task of the result stage
+    runEvent(makeCompletionEvent(
+      taskSets.last.tasks(0), Success, 42,
+      Seq.empty, createFakeTaskInfoWithId(0)))
+
+    // Fail the second task with FetchFailed.
+    runEvent(makeCompletionEvent(
+      taskSets.last.tasks(1),
+      FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0, "ignored"),
+      null))
+
+    // The job should fail because Spark can't rollback the result stage.
+    assert(failure != null && failure.getMessage.contains("Spark cannot rollback"))
+  }
+
+  test("SPARK-23207: cannot rollback a result stage") {
+    val shuffleMapRdd = new MyRDD(sc, 2, Nil, indeterminate = true)
+    assertResultStageFailToRollback(shuffleMapRdd)
+  }
+
+  test("SPARK-23207: local checkpoint fail to rollback (checkpointed before)") 
{
+    val shuffleMapRdd = new MyCheckpointRDD(sc, 2, Nil, indeterminate = true)
+    shuffleMapRdd.localCheckpoint()
+    shuffleMapRdd.doCheckpoint()
+    assertResultStageFailToRollback(shuffleMapRdd)
+  }
+
+  test("SPARK-23207: local checkpoint fail to rollback (checkpointing now)") {
+    val shuffleMapRdd = new MyCheckpointRDD(sc, 2, Nil, indeterminate = true)
+    shuffleMapRdd.localCheckpoint()
+    assertResultStageFailToRollback(shuffleMapRdd)
+  }
+
+  private def assertResultStageNotRollbacked(mapRdd: MyRDD): Unit = {
+    val shuffleDep = new ShuffleDependency(mapRdd, new HashPartitioner(2))
+    val shuffleId = shuffleDep.shuffleId
+    val finalRdd = new MyRDD(sc, 2, List(shuffleDep), tracker = 
mapOutputTracker)
+
+    submit(finalRdd, Array(0, 1))
+
+    completeShuffleMapStageSuccessfully(taskSets.length - 1, 0, 
numShufflePartitions = 2)
+    assert(mapOutputTracker.findMissingPartitions(shuffleId) === 
Some(Seq.empty))
+
+    // Finish the first task of the result stage
+    runEvent(makeCompletionEvent(
+      taskSets.last.tasks(0), Success, 42,
+      Seq.empty, createFakeTaskInfoWithId(0)))
+
+    // Fail the second task with FetchFailed.
+    runEvent(makeCompletionEvent(
+      taskSets.last.tasks(1),
+      FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0, "ignored"),
+      null))
+
+    assert(failure == null, "job should not fail")
+    val failedStages = scheduler.failedStages.toSeq
+    assert(failedStages.length == 2)
+    // Shuffle blocks of "hostA" are lost, so the first task of the map stage needs to retry.
+    assert(failedStages.collect {
+      case stage: ShuffleMapStage if stage.shuffleDep.shuffleId == shuffleId 
=> stage
+    }.head.findMissingPartitions() == Seq(0))
+    // The first task of the result stage remains completed.
+    assert(failedStages.collect {
+      case stage: ResultStage => stage
+    }.head.findMissingPartitions() == Seq(1))
+  }
+
+  test("SPARK-23207: reliable checkpoint can avoid rollback (checkpointed 
before)") {
+    sc.setCheckpointDir(Utils.createTempDir().getCanonicalPath)
+    val shuffleMapRdd = new MyCheckpointRDD(sc, 2, Nil, indeterminate = true)
+    shuffleMapRdd.checkpoint()
+    shuffleMapRdd.doCheckpoint()
+    assertResultStageNotRollbacked(shuffleMapRdd)
+  }
+
+  test("SPARK-23207: reliable checkpoint fail to rollback (checkpointing 
now)") {
+    sc.setCheckpointDir(Utils.createTempDir().getCanonicalPath)
+    val shuffleMapRdd = new MyCheckpointRDD(sc, 2, Nil, indeterminate = true)
+    shuffleMapRdd.checkpoint()
+    assertResultStageFailToRollback(shuffleMapRdd)
+  }
+
   /**
    * Assert that the supplied TaskSet has exactly the given hosts as its 
preferred locations.
    * Note that this checks only the host and not the executor ID.
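
The checkpoint-related tests above mirror the advice in the new error message: if the input of a `repartition` is made reproducible with a reliable checkpoint, a retried stage sees the same data and no rollback is needed. A user-level sketch of that workaround, where the checkpoint directory and the local master are placeholders, not recommended settings:

    import org.apache.spark.{SparkConf, SparkContext}

    object CheckpointBeforeRepartition {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(new SparkConf().setAppName("sketch").setMaster("local[2]"))
        sc.setCheckpointDir("/tmp/spark-checkpoints")  // assumption: any reliable (e.g. HDFS) path

        val input = sc.parallelize(1 to 1000000)
        input.checkpoint()
        input.count()  // run an action first so the checkpoint is actually materialized

        // A retried stage now re-reads the checkpointed data, which is deterministic.
        val repartitioned = input.repartition(10)
        println(repartitioned.count())
        sc.stop()
      }
    }

Note that the "checkpointing now" tests still fail the job: the checkpoint only helps once it has been materialized before the repartition runs.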

http://git-wip-us.apache.org/repos/asf/spark/blob/3158fc3c/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala
index c0ba513..4496afb 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala
@@ -247,6 +247,9 @@ object ShuffleExchange {
       case _ => sys.error(s"Exchange not implemented for $newPartitioning")
     }
 
+    val isRoundRobin = newPartitioning.isInstanceOf[RoundRobinPartitioning] &&
+      newPartitioning.numPartitions > 1
+
     val rddWithPartitionIds: RDD[Product2[Int, InternalRow]] = {
       // [SPARK-23207] Have to make sure the generated RoundRobinPartitioning 
is deterministic,
       // otherwise a retry task may output different rows and thus lead to 
data loss.
@@ -256,9 +259,7 @@ object ShuffleExchange {
       //
       // Note that we don't perform local sort if the new partitioning has 
only 1 partition, under
       // that case all output rows go to the same partition.
-      val newRdd = if (SparkEnv.get.conf.get(SQLConf.SORT_BEFORE_REPARTITION) &&
-          newPartitioning.numPartitions > 1 &&
-          newPartitioning.isInstanceOf[RoundRobinPartitioning]) {
+      val newRdd = if (isRoundRobin && SparkEnv.get.conf.get(SQLConf.SORT_BEFORE_REPARTITION)) {
         rdd.mapPartitionsInternal { iter =>
           val recordComparatorSupplier = new Supplier[RecordComparator] {
             override def get: RecordComparator = new RecordBinaryComparator()
@@ -294,17 +295,19 @@ object ShuffleExchange {
         rdd
       }
 
+      // The round-robin function is order sensitive if we don't sort the input.
+      val isOrderSensitive = isRoundRobin &&
+        !SparkEnv.get.conf.get(SQLConf.SORT_BEFORE_REPARTITION)
       if (needToCopyObjectsBeforeShuffle(part, serializer)) {
-        newRdd.mapPartitionsInternal { iter =>
+        newRdd.mapPartitionsWithIndexInternal((_, iter) => {
           val getPartitionKey = getPartitionKeyExtractor()
           iter.map { row => (part.getPartition(getPartitionKey(row)), row.copy()) }
-        }
+        }, isOrderSensitive = isOrderSensitive)
       } else {
-        newRdd.mapPartitionsInternal { iter =>
+        newRdd.mapPartitionsWithIndexInternal((_, iter) => {
           val getPartitionKey = getPartitionKeyExtractor()
           val mutablePair = new MutablePair[Int, InternalRow]()
           iter.map { row => mutablePair.update(part.getPartition(getPartitionKey(row)), row) }
-        }
+        }, isOrderSensitive = isOrderSensitive)
       }
     }
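
The `isOrderSensitive` flag above exists because round-robin partitioning assigns a row to a partition purely by the row's position in the iterator, so a retried task that sees the same rows in a different order produces different partition contents. A Spark-free sketch of that effect, and of why sorting the input first (the idea behind `SQLConf.SORT_BEFORE_REPARTITION`) restores determinism:

    object RoundRobinOrderSketch {
      // Assign each row to a partition by its position in the input; this mirrors the idea of
      // round-robin partitioning, not Spark's actual implementation.
      def roundRobin[T](rows: Seq[T], numPartitions: Int): Map[Int, Seq[T]] =
        rows.zipWithIndex
          .groupBy { case (_, i) => i % numPartitions }
          .map { case (p, pairs) => p -> pairs.map(_._1) }

      def main(args: Array[String]): Unit = {
        val runA = roundRobin(Seq("a", "b", "c", "d"), 2)
        val runB = roundRobin(Seq("b", "a", "d", "c"), 2)  // same rows, different arrival order
        println(runA)  // e.g. 0 -> List(a, c), 1 -> List(b, d)
        println(runB)  // e.g. 0 -> List(b, d), 1 -> List(a, c): different partition contents
        // Sorting the input first makes the assignment independent of arrival order:
        println(roundRobin(Seq("b", "a", "d", "c").sorted, 2) == runA)  // true
      }
    }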
 

