mridulm commented on a change in pull request #33034:
URL: https://github.com/apache/spark/pull/33034#discussion_r676093603



##########
File path: 
common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ErrorHandler.java
##########
@@ -91,15 +112,42 @@ public boolean shouldRetryError(Throwable t) {
           t.getCause() instanceof FileNotFoundException)) {
         return false;
       }
-      // If the block is too late, there is no need to retry it
-      return !Throwables.getStackTraceAsString(t).contains(TOO_LATE_MESSAGE_SUFFIX);
+      // If the block is too late or a stale block push, there is no need to retry it
+      return !(Throwables.getStackTraceAsString(t).contains(TOO_LATE_MESSAGE_SUFFIX) ||

Review comment:
       Pull `Throwables.getStackTraceAsString(t)` as a local variable and check 
against that instead of recomputing it three times, like in `shouldLogError` 
below
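
   For illustration, a sketch of that shape using only identifiers already present in this diff (not a definitive patch):

   ```
      // Compute the stack trace string once and reuse it for all three checks,
      // mirroring the existing pattern in shouldLogError below.
      String errorStackTrace = Throwables.getStackTraceAsString(t);
      return !(errorStackTrace.contains(TOO_LATE_MESSAGE_SUFFIX) ||
          errorStackTrace.contains(STALE_BLOCK_PUSH) ||
          errorStackTrace.contains(STALE_SHUFFLE_FINALIZE));
   ```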

##########
File path: 
common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ErrorHandler.java
##########
@@ -81,6 +81,27 @@ default boolean shouldLogError(Throwable t) {
     public static final String IOEXCEPTIONS_EXCEEDED_THRESHOLD_PREFIX =
       "IOExceptions exceeded the threshold";
 
+    /**
+     * String constant used for generating exception messages indicating the server rejecting
+     * a block push since shuffle blocks of a higher shuffleMergeId for the shuffle are already
+     * being pushed. This typically happens in the case of indeterminate stage retries where,
+     * if a stage attempt fails, the entirety of the shuffle output needs to be rolled back.
+     * For more details refer to SPARK-23243, SPARK-25341 and SPARK-32923.
+     */
+    public static final String STALE_BLOCK_PUSH =
+        "stale block push as shuffle blocks of a higher shuffleMergeId for the shuffle is already being pushed";
+
+    /**
+     * String constant used for generating exception messages indicating the server rejecting
+     * a shuffle finalize request since shuffle blocks of a higher shuffleMergeId for the
+     * shuffle are already being pushed. This typically happens in the case of indeterminate
+     * stage retries where, if a stage attempt fails, the entirety of the shuffle output needs
+     * to be rolled back. For more details refer to SPARK-23243, SPARK-25341 and SPARK-32923.
+     */
+    public static final String STALE_SHUFFLE_FINALIZE =

Review comment:
       Is this expected to be a `_SUFFIX` ?

##########
File path: 
common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
##########
@@ -191,36 +208,28 @@ private AbstractFetchShuffleBlocks createFetchShuffleMsgAndBuildBlockIds(
       }
     }
     assert(blockIdIndex == this.blockIds.length);
-    Set<Number> primaryIds = primaryIdToBlocksInfo.keySet();
-    if (!areMergedChunks) {
-      long[] mapIds = Longs.toArray(primaryIds);
-      return new FetchShuffleBlocks(
-        appId, execId, shuffleId, mapIds, secondaryIdsArray, batchFetchEnabled);
-    } else {
-      int[] reduceIds = Ints.toArray(primaryIds);
-      return new FetchShuffleBlockChunks(appId, execId, shuffleId, reduceIds, secondaryIdsArray);
-    }
+    return secondaryIds;
   }
 
-  /** Split the shuffleBlockId and return shuffleId, mapId and reduceIds. */
+  /**
+   * Split the blockId and return accordingly
+   * shuffleChunk - return shuffleId, shuffleMergeId, reduceId and chunkIds
+   * shuffle block - return shuffleId, mapId, reduceId
+   * shuffle batch block - return shuffleId, mapId, begin reduceId and end reduceId
+   */
   private String[] splitBlockId(String blockId) {
     String[] blockIdParts = blockId.split("_");
     // For batch block id, the format contains shuffleId, mapId, begin reduceId, end reduceId.
-    // For single block id, the format contains shuffleId, mapId, educeId.
-    // For single block chunk id, the format contains shuffleId, reduceId, chunkId.
+    // For single block id, the format contains shuffleId, mapId, reduceId.
+    // For single block chunk id, the format contains shuffleId, shuffleMergeId, reduceId, chunkId.
     if (blockIdParts.length < 4 || blockIdParts.length > 5) {

Review comment:
       `blockIdParts.length == 5`? We don't have `4` anymore, do we?

##########
File path: 
common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java
##########
@@ -380,9 +381,10 @@ protected Ratio getRatio() {
      } else if (blockId0Parts.length == 4 && blockId0Parts[0].equals(SHUFFLE_CHUNK_ID)) {

Review comment:
       The length should be 5, right?
   If yes, we need new/updated tests for this - this should have been caught.

##########
File path: 
common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RemoteBlockPushResolver.java
##########
@@ -135,51 +150,87 @@ protected AppShuffleInfo validateAndGetAppShuffleInfo(String appId) {
   private AppShufflePartitionInfo getOrCreateAppShufflePartitionInfo(
       AppShuffleInfo appShuffleInfo,
       int shuffleId,
+      int shuffleMergeId,
       int reduceId) {
-    File dataFile = appShuffleInfo.getMergedShuffleDataFile(shuffleId, reduceId);
-    ConcurrentMap<Integer, Map<Integer, AppShufflePartitionInfo>> partitions =
+    ConcurrentMap<Integer, Map<Integer, Map<Integer, AppShufflePartitionInfo>>> partitions =
       appShuffleInfo.partitions;
-    Map<Integer, AppShufflePartitionInfo> shufflePartitions =
-      partitions.compute(shuffleId, (id, map) -> {
-        if (map == null) {
-          // If this partition is already finalized then the partitions map will not contain the
-          // shuffleId but the data file would exist. In that case the block is considered late.
-          if (dataFile.exists()) {
+    Map<Integer, Map<Integer, AppShufflePartitionInfo>> shuffleMergePartitions =
+      partitions.compute(shuffleId, (id, shuffleMergePartitionsMap) -> {
+        if (shuffleMergePartitionsMap == null) {
+          logger.info("Creating a new attempt for shuffle blocks push request for"
+              + " shuffle {} with shuffleMergeId {} for application {}_{}", shuffleId,
+              shuffleMergeId, appShuffleInfo.appId, appShuffleInfo.attemptId));
+          Map<Integer, Map<Integer, AppShufflePartitionInfo>> newShuffleMergePartitions
+            = new ConcurrentHashMap<>();
+          Map<Integer, AppShufflePartitionInfo> newPartitionsMap = new ConcurrentHashMap<>();
+          newShuffleMergePartitions.put(shuffleMergeId, newPartitionsMap);
+          return newShuffleMergePartitions;
+        } else if (shuffleMergePartitionsMap.containsKey(shuffleMergeId)) {
+          return shuffleMergePartitionsMap;
+        } else {
+          int latestShuffleMergeId = shuffleMergePartitionsMap.keySet().stream()
+            .mapToInt(v -> v).max().orElse(UNDEFINED_SHUFFLE_MERGE_ID);
+          if (latestShuffleMergeId > shuffleMergeId) {
+            logger.info("Rejecting shuffle blocks push request for shuffle {} with"
+                + " shuffleMergeId {} for application {}_{} as a higher shuffleMergeId"
+                + " {} request is already seen", shuffleId, shuffleMergeId,
+                appShuffleInfo.appId, appShuffleInfo.attemptId, latestShuffleMergeId));
+            // Reject the request as we have already seen a higher shuffleMergeId than the
+            // current incoming one
             return null;

Review comment:
       This will end up removing `shuffleId` from the `partitions` map.
   Please look up the semantics of the `compute` method for more details.
   
   I would suggest throwing an Exception and handling it when we have to reject stale pushes.
   An alternative would be to pass this via an `AtomicBoolean` or `AtomicReference` - but continue to return `shuffleMergePartitionsMap`.
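
   To make the `compute` semantics concrete, here is a small self-contained demo; the `StalePushException` name is hypothetical and not part of this PR:

   ```
   import java.util.concurrent.ConcurrentHashMap;
   import java.util.concurrent.ConcurrentMap;

   public class ComputeSemanticsDemo {
     // Hypothetical unchecked exception used to signal a stale push from inside compute().
     static class StalePushException extends RuntimeException {
       StalePushException(String message) { super(message); }
     }

     public static void main(String[] args) {
       ConcurrentMap<Integer, String> partitions = new ConcurrentHashMap<>();
       partitions.put(1, "shuffleMergeId=2");

       // Returning null from the remapping function REMOVES the mapping for the key.
       partitions.compute(1, (id, v) -> null);
       System.out.println(partitions.containsKey(1)); // false

       // Throwing from the remapping function leaves the existing mapping unchanged,
       // so the shuffleId entry survives and the caller can reject the stale push.
       partitions.put(1, "shuffleMergeId=2");
       try {
         partitions.compute(1, (id, v) -> {
           throw new StalePushException("stale block push");
         });
       } catch (StalePushException e) {
         // reject the push here
       }
       System.out.println(partitions.containsKey(1)); // true
     }
   }
   ```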

##########
File path: 
common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RemoteBlockPushResolver.java
##########
@@ -270,18 +334,36 @@ public void applicationRemoved(String appId, boolean cleanupLocalDirs) {
   void closeAndDeletePartitionFilesIfNeeded(
       AppShuffleInfo appShuffleInfo,
       boolean cleanupLocalDirs) {
-    for (Map<Integer, AppShufflePartitionInfo> partitionMap : appShuffleInfo.partitions.values()) {
-      for (AppShufflePartitionInfo partitionInfo : partitionMap.values()) {
-        synchronized (partitionInfo) {
-          partitionInfo.closeAllFiles();
-        }
+    List<AppShufflePartitionInfo> partitionsToCleanUp =
+        appShuffleInfo
+          .partitions.values().stream()
+            .flatMap(x -> x.values().stream()).flatMap(x -> x.values().stream())
+              .collect(Collectors.toList());

Review comment:
       Replace `collect` with `forEach`
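
   i.e., something along these lines (a sketch against the types in this diff, assuming the collected list is only used for the close-files pass shown in the removed loop):

   ```
    appShuffleInfo.partitions.values().stream()
      .flatMap(x -> x.values().stream())
      .flatMap(x -> x.values().stream())
      .forEach(partitionInfo -> {
        // Close each partition's files directly; no intermediate List is materialized.
        synchronized (partitionInfo) {
          partitionInfo.closeAllFiles();
        }
      });
   ```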

##########
File path: 
common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RemoteBlockPushResolver.java
##########
@@ -408,14 +507,24 @@ public MergeStatuses finalizeShuffleMerge(FinalizeShuffleMerge msg) throws IOExc
       throw new IllegalArgumentException(
         String.format("The attempt id %s in this FinalizeShuffleMerge message does not match "
           + "with the current attempt id %s stored in shuffle service for application %s",
-          msg.appAttemptId, appShuffleInfo.attemptId, msg.appId));
+            msg.appAttemptId, appShuffleInfo.attemptId, msg.appId));
+    }
+    Map<Integer, Map<Integer, AppShufflePartitionInfo>> shuffleMergePartitions =
+        appShuffleInfo.partitions.get(msg.shuffleId);
+    Map<Integer, AppShufflePartitionInfo> shufflePartitions = shuffleMergePartitions.get(msg.shuffleMergeId);
+    if (shufflePartitions == STALE_SHUFFLE_PARTITIONS) {
+      throw new RuntimeException(String.format("Shuffle merge finalize request for shuffle %s"

Review comment:
       Handle `null == shufflePartitions`

##########
File path: 
common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RemoteBlockPushResolver.java
##########
@@ -649,13 +759,16 @@ public void onData(String streamId, ByteBuffer buf) throws IOException {
       // memory, while still providing the necessary guarantee.
       synchronized (partitionInfo) {
         Map<Integer, AppShufflePartitionInfo> shufflePartitions =
-          appShuffleInfo.partitions.get(partitionInfo.shuffleId);
-        // If the partitionInfo corresponding to (appId, shuffleId, reduceId) is no longer present
-        // then it means that the shuffle merge has already been finalized. We should thus ignore
-        // the data and just drain the remaining bytes of this message. This check should be
-        // placed inside the synchronized block to make sure that checking the key is still
-        // present and processing the data is atomic.
-        if (shufflePartitions == null || !shufflePartitions.containsKey(partitionInfo.reduceId)) {
+          appShuffleInfo.partitions.get(partitionInfo.shuffleId)
+            .get(partitionInfo.shuffleMergeId);

Review comment:
       Handle `null`'s

##########
File path: core/src/main/scala/org/apache/spark/MapOutputTracker.scala
##########
@@ -1450,11 +1450,11 @@ private[spark] object MapOutputTracker extends Logging {
          val remainingMapStatuses = if (mergeStatus != null && mergeStatus.totalSize > 0) {
            // If MergeStatus is available for the given partition, add location of the
            // pre-merged shuffle partition for this partition ID. Here we create a
-            // ShuffleBlockId with mapId being SHUFFLE_PUSH_MAP_ID to indicate this is
+            // ShufflePushBlockId with mapId being SHUFFLE_PUSH_MAP_ID to indicate this is
            // a merged shuffle block.
            splitsByAddress.getOrElseUpdate(mergeStatus.location, ListBuffer()) +=
-              ((ShuffleBlockId(shuffleId, SHUFFLE_PUSH_MAP_ID, partId), mergeStatus.totalSize,
-                SHUFFLE_PUSH_MAP_ID))
+              ((ShuffleMergedBlockId(shuffleId, mergeStatus.shuffleMergeId, partId),
+                mergeStatus.totalSize, -1))

Review comment:
       `-1` -> `SHUFFLE_PUSH_MAP_ID`

##########
File path: core/src/main/scala/org/apache/spark/storage/BlockId.scala
##########
@@ -172,11 +195,15 @@ object BlockId {
   val SHUFFLE_BATCH = "shuffle_([0-9]+)_([0-9]+)_([0-9]+)_([0-9]+)".r
   val SHUFFLE_DATA = "shuffle_([0-9]+)_([0-9]+)_([0-9]+).data".r
   val SHUFFLE_INDEX = "shuffle_([0-9]+)_([0-9]+)_([0-9]+).index".r
-  val SHUFFLE_PUSH = "shufflePush_([0-9]+)_([0-9]+)_([0-9]+)".r
-  val SHUFFLE_MERGED_DATA = "shuffleMerged_([_A-Za-z0-9]*)_([0-9]+)_([0-9]+).data".r
-  val SHUFFLE_MERGED_INDEX = "shuffleMerged_([_A-Za-z0-9]*)_([0-9]+)_([0-9]+).index".r
-  val SHUFFLE_MERGED_META = "shuffleMerged_([_A-Za-z0-9]*)_([0-9]+)_([0-9]+).meta".r
-  val SHUFFLE_CHUNK = "shuffleChunk_([0-9]+)_([0-9]+)_([0-9]+)".r
+  val SHUFFLE_PUSH = "shufflePush_([0-9]+)_(-?[0-9]+)_([0-9]+)_([0-9]+)".r

Review comment:
       Currently, in this PR, we are not handling this as stated above.
   Even if we were to do it, I would prefer to use `0` for deterministic stages - and `> 0` for others, including the very first stage attempt execution for a non-deterministic stage.

##########
File path: 
common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RemoteBlockPushResolver.java
##########
@@ -135,51 +150,87 @@ protected AppShuffleInfo validateAndGetAppShuffleInfo(String appId) {
   private AppShufflePartitionInfo getOrCreateAppShufflePartitionInfo(
       AppShuffleInfo appShuffleInfo,
       int shuffleId,
+      int shuffleMergeId,
       int reduceId) {
-    File dataFile = appShuffleInfo.getMergedShuffleDataFile(shuffleId, reduceId);
-    ConcurrentMap<Integer, Map<Integer, AppShufflePartitionInfo>> partitions =
+    ConcurrentMap<Integer, Map<Integer, Map<Integer, AppShufflePartitionInfo>>> partitions =
       appShuffleInfo.partitions;
-    Map<Integer, AppShufflePartitionInfo> shufflePartitions =
-      partitions.compute(shuffleId, (id, map) -> {
-        if (map == null) {
-          // If this partition is already finalized then the partitions map will not contain the
-          // shuffleId but the data file would exist. In that case the block is considered late.
-          if (dataFile.exists()) {
+    Map<Integer, Map<Integer, AppShufflePartitionInfo>> shuffleMergePartitions =
+      partitions.compute(shuffleId, (id, shuffleMergePartitionsMap) -> {
+        if (shuffleMergePartitionsMap == null) {
+          logger.info("Creating a new attempt for shuffle blocks push request for"
+              + " shuffle {} with shuffleMergeId {} for application {}_{}", shuffleId,
+              shuffleMergeId, appShuffleInfo.appId, appShuffleInfo.attemptId));
+          Map<Integer, Map<Integer, AppShufflePartitionInfo>> newShuffleMergePartitions
+            = new ConcurrentHashMap<>();
+          Map<Integer, AppShufflePartitionInfo> newPartitionsMap = new ConcurrentHashMap<>();
+          newShuffleMergePartitions.put(shuffleMergeId, newPartitionsMap);
+          return newShuffleMergePartitions;
+        } else if (shuffleMergePartitionsMap.containsKey(shuffleMergeId)) {
+          return shuffleMergePartitionsMap;
+        } else {
+          int latestShuffleMergeId = shuffleMergePartitionsMap.keySet().stream()
+            .mapToInt(v -> v).max().orElse(UNDEFINED_SHUFFLE_MERGE_ID);
+          if (latestShuffleMergeId > shuffleMergeId) {
+            logger.info("Rejecting shuffle blocks push request for shuffle {} with"
+                + " shuffleMergeId {} for application {}_{} as a higher shuffleMergeId"
+                + " {} request is already seen", shuffleId, shuffleMergeId,
+                appShuffleInfo.appId, appShuffleInfo.attemptId, latestShuffleMergeId));
+            // Reject the request as we have already seen a higher shuffleMergeId than the
+            // current incoming one
             return null;
+          } else {
+            // Higher shuffleMergeId seen for the shuffle ID meaning new stage attempt is being
+            // run for the shuffle ID. Close and clean up old shuffleMergeId files,
+            // happens in the non-deterministic stage retries
+            logger.info("Creating a new attempt for shuffle blocks push request for"
+               + " shuffle {} with shuffleMergeId {} for application {}_{} since it is"
+               + " higher than the latest shuffleMergeId {} already seen", shuffleId,
+               shuffleMergeId, appShuffleInfo.appId, appShuffleInfo.attemptId,
+               latestShuffleMergeId));
+            if (null != shuffleMergePartitionsMap.get(latestShuffleMergeId)) {
+              Map<Integer, AppShufflePartitionInfo> shufflePartitions =
+                shuffleMergePartitionsMap.get(latestShuffleMergeId);
+              mergedShuffleCleaner.execute(() ->
+                closeAndDeletePartitionFiles(shufflePartitions));
+            }
+            shuffleMergePartitionsMap.put(latestShuffleMergeId, STALE_SHUFFLE_PARTITIONS);
+            Map<Integer, AppShufflePartitionInfo> newPartitionsMap = new ConcurrentHashMap<>();
+            shuffleMergePartitionsMap.put(shuffleMergeId, newPartitionsMap);
+            return shuffleMergePartitionsMap;
           }
-          return new ConcurrentHashMap<>();
-        } else {
-          return map;
         }
       });
-    if (shufflePartitions == null) {
+
+    Map<Integer, AppShufflePartitionInfo> shufflePartitions = shuffleMergePartitions.get(shuffleMergeId);
+    if (shufflePartitions == FINALIZED_SHUFFLE_PARTITIONS
+        || shufflePartitions == STALE_SHUFFLE_PARTITIONS) {
+      // It only gets here when shufflePartitions is either FINALIZED_SHUFFLE_PARTITIONS or
+      // STALE_SHUFFLE_PARTITIONS. This happens in 2 cases:
+      // 1. Incoming block request is for an older shuffleMergeId of a shuffle (i.e. higher
+      // shuffle sequence Id blocks are already being merged for this shuffle Id).
+      // 2. Shuffle for the current shuffleMergeId is already finalized.
       return null;
     }
 
+    File dataFile = appShuffleInfo.getMergedShuffleDataFile(shuffleId, shuffleMergeId, reduceId);
     return shufflePartitions.computeIfAbsent(reduceId, key -> {
-      // It only gets here when the key is not present in the map. This could either
-      // be the first time the merge manager receives a pushed block for a given application
-      // shuffle partition, or after the merged shuffle file is finalized. We handle these
-      // two cases accordingly by checking if the file already exists.
+      // It only gets here when the key is not present in the map, i.e. the first time the merge
+      // manager receives a pushed block for a given application shuffle partition.
       File indexFile =
-        appShuffleInfo.getMergedShuffleIndexFile(shuffleId, reduceId);
+        appShuffleInfo.getMergedShuffleIndexFile(shuffleId, shuffleMergeId, reduceId);
       File metaFile =
-        appShuffleInfo.getMergedShuffleMetaFile(shuffleId, reduceId);
+        appShuffleInfo.getMergedShuffleMetaFile(shuffleId, shuffleMergeId, reduceId);
       try {
-        if (dataFile.exists()) {
-          return null;
-        } else {
-          return newAppShufflePartitionInfo(
-            appShuffleInfo.appId, shuffleId, reduceId, dataFile, indexFile, metaFile);
-        }
+        return newAppShufflePartitionInfo(appShuffleInfo.appId, shuffleId, shuffleMergeId,
+          reduceId, dataFile, indexFile, metaFile);
       } catch (IOException e) {
         logger.error(
           "Cannot create merged shuffle partition with data file {}, index file {}, and "
             + "meta file {}", dataFile.getAbsolutePath(),
             indexFile.getAbsolutePath(), metaFile.getAbsolutePath());
         throw new RuntimeException(
           String.format("Cannot initialize merged shuffle partition for appId %s shuffleId %s "
-            + "reduceId %s", appShuffleInfo.appId, shuffleId, reduceId), e);
+            + "shuffleMergeId %s reduceId %s", appShuffleInfo.appId, shuffleId, shuffleMergeId, reduceId), e);
       }
     });

Review comment:
       I am not sure of the removal of the file existence checks.
   +CC @zhouyejoe, @otterc 

##########
File path: 
common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ErrorHandler.java
##########
@@ -81,6 +81,27 @@ default boolean shouldLogError(Throwable t) {
     public static final String IOEXCEPTIONS_EXCEEDED_THRESHOLD_PREFIX =
       "IOExceptions exceeded the threshold";
 
+    /**
+     * String constant used for generating exception messages indicating the server rejecting
+     * a block push since shuffle blocks of a higher shuffleMergeId for the shuffle are already
+     * being pushed. This typically happens in the case of indeterminate stage retries where,
+     * if a stage attempt fails, the entirety of the shuffle output needs to be rolled back.
+     * For more details refer to SPARK-23243, SPARK-25341 and SPARK-32923.
+     */
+    public static final String STALE_BLOCK_PUSH =

Review comment:
       Is this expected to be a `_SUFFIX` ?

##########
File path: 
common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java
##########
@@ -380,9 +381,10 @@ protected Ratio getRatio() {
      } else if (blockId0Parts.length == 4 && blockId0Parts[0].equals(SHUFFLE_CHUNK_ID)) {
        requestForMergedBlockChunks = true;
        final int shuffleId = Integer.parseInt(blockId0Parts[1]);
+        final int shuffleMergeId = Integer.parseInt(blockId0Parts[2]);
        final int[] reduceIdAndChunkIds = shuffleMapIdAndReduceIds(blockIds, shuffleId);

Review comment:
       If I am reading this correctly, `shuffleMapIdAndReduceIds` is not handling the introduction of `shuffleMergeId` into the block name pattern ?
   If yes, why are tests not catching this ?
   
   +CC @otterc who wrote this initially to take a closer look.

##########
File path: 
common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RemoteBlockPushResolver.java
##########
@@ -379,7 +473,12 @@ public void onData(String streamId, ByteBuffer buf) {
 
         @Override
         public void onComplete(String streamId) {
-          if (isTooLate) {
+          if (isStaleBlock) {
+            // Throw an exception here so the block data is drained from channel and server
+            // responds RpcFailure to the client.
+            throw new RuntimeException(String.format("Block %s %s is", streamId,

Review comment:
       The `is` suffix should be before the second `%s` ?
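
   For example, assuming the truncated second format argument is the stale-push message constant (the diff is cut off here, so this is illustrative):

   ```
            throw new RuntimeException(String.format("Block %s is %s", streamId,
              ErrorHandler.BlockPushErrorHandler.STALE_BLOCK_PUSH));
   ```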

##########
File path: 
common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ErrorHandler.java
##########
@@ -91,15 +112,42 @@ public boolean shouldRetryError(Throwable t) {
           t.getCause() instanceof FileNotFoundException)) {
         return false;
       }
-      // If the block is too late, there is no need to retry it
-      return !Throwables.getStackTraceAsString(t).contains(TOO_LATE_MESSAGE_SUFFIX);
+      // If the block is too late or a stale block push, there is no need to retry it
+      return !(Throwables.getStackTraceAsString(t).contains(TOO_LATE_MESSAGE_SUFFIX) ||
+          Throwables.getStackTraceAsString(t).contains(STALE_BLOCK_PUSH) ||
+          Throwables.getStackTraceAsString(t).contains(STALE_SHUFFLE_FINALIZE));
     }
 
     @Override
     public boolean shouldLogError(Throwable t) {
       String errorStackTrace = Throwables.getStackTraceAsString(t);
       return !errorStackTrace.contains(BLOCK_APPEND_COLLISION_DETECTED_MSG_PREFIX) &&
-        !errorStackTrace.contains(TOO_LATE_MESSAGE_SUFFIX);
+        !errorStackTrace.contains(TOO_LATE_MESSAGE_SUFFIX) &&
+          !errorStackTrace.contains(STALE_BLOCK_PUSH) &&
+            !errorStackTrace.contains(STALE_SHUFFLE_FINALIZE);
+    }
+  }
+
+  class BlockFetchErrorHandler implements ErrorHandler {
+    /**
+     * String constant used for generating exception messages indicating the server rejecting
+     * a block fetch since shuffle blocks of a higher shuffleMergeId for the shuffle are
+     * already found. This typically happens in the case of indeterminate stage retries where,
+     * if a stage attempt fails, the entirety of the shuffle output needs to be rolled back.
+     * For more details refer to SPARK-23243 and SPARK-25341.
+     */
+    public static final String STALE_BLOCK_FETCH =

Review comment:
       Should this be a `_SUFFIX` ?

##########
File path: 
common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
##########
@@ -125,63 +127,78 @@ private AbstractFetchShuffleBlocks createFetchShuffleBlocksOrChunksMsg(
      String execId,
      String[] blockIds) {
    if (blockIds[0].startsWith(SHUFFLE_CHUNK_PREFIX)) {
-      return createFetchShuffleMsgAndBuildBlockIds(appId, execId, blockIds, true);
+      return createFetchShuffleChunksMsg(appId, execId, blockIds);
    } else {
-      return createFetchShuffleMsgAndBuildBlockIds(appId, execId, blockIds, false);
+      return createFetchShuffleBlocksMsg(appId, execId, blockIds);
    }
  }
 
-  /**
-   * Create FetchShuffleBlocks/FetchShuffleBlockChunks message and rebuild internal blockIds by
-   * analyzing the passed in blockIds.
-   */
-  private AbstractFetchShuffleBlocks createFetchShuffleMsgAndBuildBlockIds(
-      String appId,
-      String execId,
-      String[] blockIds,
-      boolean areMergedChunks) {
+  private AbstractFetchShuffleBlocks createFetchShuffleBlocksMsg(String appId, String execId, String[] blockIds) {
    String[] firstBlock = splitBlockId(blockIds[0]);
    int shuffleId = Integer.parseInt(firstBlock[1]);
    boolean batchFetchEnabled = firstBlock.length == 5;
-
-    // In case of FetchShuffleBlocks, primaryId is mapId. For FetchShuffleBlockChunks, primaryId
-    // is reduceId.
-    LinkedHashMap<Number, BlocksInfo> primaryIdToBlocksInfo = new LinkedHashMap<>();
+    Map<Long, BlocksInfo> mapIdToBlocksInfo = new LinkedHashMap<>();
    for (String blockId : blockIds) {
      String[] blockIdParts = splitBlockId(blockId);
      if (Integer.parseInt(blockIdParts[1]) != shuffleId) {
-        throw new IllegalArgumentException("Expected shuffleId=" + shuffleId +
-          ", got:" + blockId);
+        throw new IllegalArgumentException("Expected shuffleId=" + shuffleId + ", got:" + blockId);
      }
-      Number primaryId;
-      if (!areMergedChunks) {
-        primaryId = Long.parseLong(blockIdParts[2]);
-      } else {
-        primaryId = Integer.parseInt(blockIdParts[2]);
-      }
-      BlocksInfo blocksInfoByPrimaryId = primaryIdToBlocksInfo.computeIfAbsent(primaryId,
-        id -> new BlocksInfo());
-      blocksInfoByPrimaryId.blockIds.add(blockId);
-      // If blockId is a regular shuffle block, then blockIdParts[3] = reduceId. If blockId is a
-      // shuffleChunk block, then blockIdParts[3] = chunkId
-      blocksInfoByPrimaryId.ids.add(Integer.parseInt(blockIdParts[3]));
+
+      long mapId = Long.parseLong(blockIdParts[2]);
+      BlocksInfo blocksInfoByMapId = mapIdToBlocksInfo.computeIfAbsent(mapId,
+          id -> new BlocksInfo());
+      blocksInfoByMapId.blockIds.add(blockId);
+      blocksInfoByMapId.ids.add(Integer.parseInt(blockIdParts[3]));
+
      if (batchFetchEnabled) {
-        // It comes here only if the blockId is a regular shuffle block not a shuffleChunk block.
        // When we read continuous shuffle blocks in batch, we will reuse reduceIds in
        // FetchShuffleBlocks to store the start and end reduce id for range
        // [startReduceId, endReduceId).
        assert(blockIdParts.length == 5);
        // blockIdParts[4] is the end reduce id for the batch range
-        blocksInfoByPrimaryId.ids.add(Integer.parseInt(blockIdParts[4]));
+        blocksInfoByMapId.ids.add(Integer.parseInt(blockIdParts[4]));
+      }
+    }
+
+    int[][] reduceIdsArray = getSecondaryIds(mapIdToBlocksInfo);
+    long[] mapIds = Longs.toArray(mapIdToBlocksInfo.keySet());
+    return new FetchShuffleBlocks(
+        appId, execId, shuffleId, mapIds, reduceIdsArray, batchFetchEnabled);
+  }
+
+  private AbstractFetchShuffleBlocks createFetchShuffleChunksMsg(String appId, String execId, String[] blockIds) {
+    String[] firstBlock = splitBlockId(blockIds[0]);
+    int shuffleId = Integer.parseInt(firstBlock[1]);
+    int shuffleMergeId = Integer.parseInt(firstBlock[2]);
+
+    Map<Integer, BlocksInfo> reduceIdToBlocksInfo = new LinkedHashMap<>();
+    for (String blockId : blockIds) {
+      String[] blockIdParts = splitBlockId(blockId);
+      if (Integer.parseInt(blockIdParts[1]) != shuffleId) {
+        throw new IllegalArgumentException("Expected shuffleId=" + shuffleId + ", got:" + blockId);
      }
+

Review comment:
       Validate `shuffleMergeId` as well, like `shuffleId` above.
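
   A sketch of the analogous guard (the message text is illustrative):

   ```
      if (Integer.parseInt(blockIdParts[2]) != shuffleMergeId) {
        throw new IllegalArgumentException(
          "Expected shuffleMergeId=" + shuffleMergeId + ", got:" + blockId);
      }
   ```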

##########
File path: 
common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ErrorHandler.java
##########
@@ -91,15 +112,42 @@ public boolean shouldRetryError(Throwable t) {
           t.getCause() instanceof FileNotFoundException)) {
         return false;
       }
-      // If the block is too late, there is no need to retry it
-      return !Throwables.getStackTraceAsString(t).contains(TOO_LATE_MESSAGE_SUFFIX);
+      // If the block is too late or a stale block push, there is no need to retry it
+      return !(Throwables.getStackTraceAsString(t).contains(TOO_LATE_MESSAGE_SUFFIX) ||
+          Throwables.getStackTraceAsString(t).contains(STALE_BLOCK_PUSH) ||
+          Throwables.getStackTraceAsString(t).contains(STALE_SHUFFLE_FINALIZE));
     }
 
     @Override
     public boolean shouldLogError(Throwable t) {
       String errorStackTrace = Throwables.getStackTraceAsString(t);
       return !errorStackTrace.contains(BLOCK_APPEND_COLLISION_DETECTED_MSG_PREFIX) &&
-        !errorStackTrace.contains(TOO_LATE_MESSAGE_SUFFIX);
+        !errorStackTrace.contains(TOO_LATE_MESSAGE_SUFFIX) &&
+          !errorStackTrace.contains(STALE_BLOCK_PUSH) &&
+            !errorStackTrace.contains(STALE_SHUFFLE_FINALIZE);
+    }
+  }
+
+  class BlockFetchErrorHandler implements ErrorHandler {
+    /**
+     * String constant used for generating exception messages indicating the server rejecting
+     * a block fetch since shuffle blocks of a higher shuffleMergeId for the shuffle are
+     * already found. This typically happens in the case of indeterminate stage retries where,
+     * if a stage attempt fails, the entirety of the shuffle output needs to be rolled back.
+     * For more details refer to SPARK-23243 and SPARK-25341.
+     */
+    public static final String STALE_BLOCK_FETCH =
+        "stale fetch as the shuffleMergeId is older than the latest shuffleMergeId";
+
+    @Override
+    public boolean shouldRetryError(Throwable t) {
+      return !Throwables.getStackTraceAsString(t).contains(STALE_BLOCK_FETCH);
+    }
+
+    @Override
+    public boolean shouldLogError(Throwable t) {
+      String errorStackTrace = Throwables.getStackTraceAsString(t);
+      return !errorStackTrace.contains(STALE_BLOCK_FETCH);

Review comment:
       nit: inline and remove local variable ?
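
   i.e., a sketch:

   ```
    @Override
    public boolean shouldLogError(Throwable t) {
      // Single use, so no local variable is needed.
      return !Throwables.getStackTraceAsString(t).contains(STALE_BLOCK_FETCH);
    }
   ```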

##########
File path: 
common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RemoteBlockPushResolver.java
##########
@@ -78,6 +80,19 @@
   public static final String MERGE_DIR_KEY = "mergeDir";
   public static final String ATTEMPT_ID_KEY = "attemptId";
   private static final int UNDEFINED_ATTEMPT_ID = -1;
+  private static final int UNDEFINED_SHUFFLE_MERGE_ID = Integer.MIN_VALUE;
+
+  // ConcurrentHashMap doesn't allow null for keys or values, which is why this is required.
+  // Marker to identify stale shuffle partitions, which typically happen in the case of
+  // indeterminate stage retries.
+  @VisibleForTesting
+  public static final Map<Integer, AppShufflePartitionInfo> STALE_SHUFFLE_PARTITIONS =
+    new ConcurrentHashMap<>();

Review comment:
       Make this `Collections.emptyMap` ?
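
   A sketch - though note that `Collections.emptyMap()` returns a shared singleton, so if `STALE_SHUFFLE_PARTITIONS` and `FINALIZED_SHUFFLE_PARTITIONS` must stay distinct under reference comparison, each marker needs its own immutable instance:

   ```
  // Immutable marker with its own identity: it cannot be mutated by accident and
  // remains distinct from any other marker created the same way.
  @VisibleForTesting
  public static final Map<Integer, AppShufflePartitionInfo> STALE_SHUFFLE_PARTITIONS =
    Collections.unmodifiableMap(new HashMap<>());
   ```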

##########
File path: 
common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RemoteBlockPushResolver.java
##########
@@ -408,14 +507,24 @@ public MergeStatuses finalizeShuffleMerge(FinalizeShuffleMerge msg) throws IOExc
      throw new IllegalArgumentException(
        String.format("The attempt id %s in this FinalizeShuffleMerge message does not match "
          + "with the current attempt id %s stored in shuffle service for application %s",
-          msg.appAttemptId, appShuffleInfo.attemptId, msg.appId));
+            msg.appAttemptId, appShuffleInfo.attemptId, msg.appId));
+    }
+    Map<Integer, Map<Integer, AppShufflePartitionInfo>> shuffleMergePartitions =
+        appShuffleInfo.partitions.get(msg.shuffleId);
+    Map<Integer, AppShufflePartitionInfo> shufflePartitions = shuffleMergePartitions.get(msg.shuffleMergeId);

Review comment:
       Handle `null == shuffleMergePartitions`

##########
File path: 
common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RemoteBlockPushResolver.java
##########
@@ -353,18 +443,22 @@ public StreamCallbackWithID receiveBlockDataAsStream(PushBlockStream msg) {
    // getting killed. When this happens, we need to distinguish the duplicate blocks as they
    // arrive. More details on this is explained in later comments.
 
+    // Track if the block is received from an older shuffleMergeId attempt.
+    final boolean isStaleBlock = partitionInfoBeforeCheck == STALE_SHUFFLE_PARTITIONS;
    // Track if the block is received after shuffle merge finalize
-    final boolean isTooLate = partitionInfoBeforeCheck == null;
-    // Check if the given block is already merged by checking the bitmap against the given map index
-    final AppShufflePartitionInfo partitionInfo = partitionInfoBeforeCheck != null
-      && partitionInfoBeforeCheck.mapTracker.contains(msg.mapIndex) ? null
-        : partitionInfoBeforeCheck;
+    final boolean isTooLate = partitionInfoBeforeCheck == FINALIZED_SHUFFLE_PARTITIONS;
+    // Check if the given block is already merged by checking the bitmap against the given map
+    // index
+    final boolean isStaleOrTooLate = (partitionInfoBeforeCheck == STALE_SHUFFLE_PARTITIONS ||
+        partitionInfoBeforeCheck == FINALIZED_SHUFFLE_PARTITIONS);

Review comment:
       `isTooLate || isStaleBlock` ?

##########
File path: 
common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RemoteBlockPushResolver.java
##########
@@ -135,51 +150,87 @@ protected AppShuffleInfo validateAndGetAppShuffleInfo(String appId) {
   private AppShufflePartitionInfo getOrCreateAppShufflePartitionInfo(
       AppShuffleInfo appShuffleInfo,
       int shuffleId,
+      int shuffleMergeId,
       int reduceId) {
-    File dataFile = appShuffleInfo.getMergedShuffleDataFile(shuffleId, reduceId);
-    ConcurrentMap<Integer, Map<Integer, AppShufflePartitionInfo>> partitions =
+    ConcurrentMap<Integer, Map<Integer, Map<Integer, AppShufflePartitionInfo>>> partitions =
       appShuffleInfo.partitions;
-    Map<Integer, AppShufflePartitionInfo> shufflePartitions =
-      partitions.compute(shuffleId, (id, map) -> {
-        if (map == null) {
-          // If this partition is already finalized then the partitions map will not contain the
-          // shuffleId but the data file would exist. In that case the block is considered late.
-          if (dataFile.exists()) {
+    Map<Integer, Map<Integer, AppShufflePartitionInfo>> shuffleMergePartitions =
+      partitions.compute(shuffleId, (id, shuffleMergePartitionsMap) -> {
+        if (shuffleMergePartitionsMap == null) {
+          logger.info("Creating a new attempt for shuffle blocks push request for"
+              + " shuffle {} with shuffleMergeId {} for application {}_{}", shuffleId,
+              shuffleMergeId, appShuffleInfo.appId, appShuffleInfo.attemptId));
+          Map<Integer, Map<Integer, AppShufflePartitionInfo>> newShuffleMergePartitions
+            = new ConcurrentHashMap<>();
+          Map<Integer, AppShufflePartitionInfo> newPartitionsMap = new ConcurrentHashMap<>();
+          newShuffleMergePartitions.put(shuffleMergeId, newPartitionsMap);
+          return newShuffleMergePartitions;
+        } else if (shuffleMergePartitionsMap.containsKey(shuffleMergeId)) {
+          return shuffleMergePartitionsMap;
+        } else {
+          int latestShuffleMergeId = shuffleMergePartitionsMap.keySet().stream()
+            .mapToInt(v -> v).max().orElse(UNDEFINED_SHUFFLE_MERGE_ID);
+          if (latestShuffleMergeId > shuffleMergeId) {
+            logger.info("Rejecting shuffle blocks push request for shuffle {} with"
+                + " shuffleMergeId {} for application {}_{} as a higher shuffleMergeId"
+                + " {} request is already seen", shuffleId, shuffleMergeId,
+                appShuffleInfo.appId, appShuffleInfo.attemptId, latestShuffleMergeId));
+            // Reject the request as we have already seen a higher shuffleMergeId than the
+            // current incoming one
             return null;
+          } else {
+            // Higher shuffleMergeId seen for the shuffle ID meaning new stage attempt is being
+            // run for the shuffle ID. Close and clean up old shuffleMergeId files,
+            // happens in the non-deterministic stage retries
+            logger.info("Creating a new attempt for shuffle blocks push request for"
+               + " shuffle {} with shuffleMergeId {} for application {}_{} since it is"
+               + " higher than the latest shuffleMergeId {} already seen", shuffleId,
+               shuffleMergeId, appShuffleInfo.appId, appShuffleInfo.attemptId,
+               latestShuffleMergeId));
+            if (null != shuffleMergePartitionsMap.get(latestShuffleMergeId)) {
+              Map<Integer, AppShufflePartitionInfo> shufflePartitions =
+                shuffleMergePartitionsMap.get(latestShuffleMergeId);
+              mergedShuffleCleaner.execute(() ->
+                closeAndDeletePartitionFiles(shufflePartitions));
+            }
+            shuffleMergePartitionsMap.put(latestShuffleMergeId, STALE_SHUFFLE_PARTITIONS);

Review comment:
       Why are we keeping all previous entries ? Just the latest (shuffleMergeId) should do ?
   
   Also, assuming current code is correct:
   ```suggestion
            Map<Integer, AppShufflePartitionInfo> latestShufflePartitions = shuffleMergePartitionsMap.get(latestShuffleMergeId);
            // latestShuffleMergeId cannot be UNDEFINED_SHUFFLE_MERGE_ID - since shuffle merge id is specific to a shuffle id.
            assert (UNDEFINED_SHUFFLE_MERGE_ID != latestShuffleMergeId);
            assert (null != latestShufflePartitions);

            if (STALE_SHUFFLE_PARTITIONS != latestShufflePartitions) {
              mergedShuffleCleaner.execute(() ->
                closeAndDeletePartitionFiles(latestShufflePartitions));
              shuffleMergePartitionsMap.put(latestShuffleMergeId, STALE_SHUFFLE_PARTITIONS);
            }
   ```

##########
File path: 
common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RemoteBlockPushResolver.java
##########
@@ -135,51 +150,87 @@ protected AppShuffleInfo validateAndGetAppShuffleInfo(String appId) {
   private AppShufflePartitionInfo getOrCreateAppShufflePartitionInfo(
       AppShuffleInfo appShuffleInfo,
       int shuffleId,
+      int shuffleMergeId,
       int reduceId) {
-    File dataFile = appShuffleInfo.getMergedShuffleDataFile(shuffleId, reduceId);
-    ConcurrentMap<Integer, Map<Integer, AppShufflePartitionInfo>> partitions =
+    ConcurrentMap<Integer, Map<Integer, Map<Integer, AppShufflePartitionInfo>>> partitions =
       appShuffleInfo.partitions;
-    Map<Integer, AppShufflePartitionInfo> shufflePartitions =
-      partitions.compute(shuffleId, (id, map) -> {
-        if (map == null) {
-          // If this partition is already finalized then the partitions map will not contain the
-          // shuffleId but the data file would exist. In that case the block is considered late.
-          if (dataFile.exists()) {
+    Map<Integer, Map<Integer, AppShufflePartitionInfo>> shuffleMergePartitions =
+      partitions.compute(shuffleId, (id, shuffleMergePartitionsMap) -> {
+        if (shuffleMergePartitionsMap == null) {
+          logger.info("Creating a new attempt for shuffle blocks push request for"
+              + " shuffle {} with shuffleMergeId {} for application {}_{}", shuffleId,
+              shuffleMergeId, appShuffleInfo.appId, appShuffleInfo.attemptId));
+          Map<Integer, Map<Integer, AppShufflePartitionInfo>> newShuffleMergePartitions
+            = new ConcurrentHashMap<>();
+          Map<Integer, AppShufflePartitionInfo> newPartitionsMap = new ConcurrentHashMap<>();
+          newShuffleMergePartitions.put(shuffleMergeId, newPartitionsMap);
+          return newShuffleMergePartitions;
+        } else if (shuffleMergePartitionsMap.containsKey(shuffleMergeId)) {
+          return shuffleMergePartitionsMap;
+        } else {
+          int latestShuffleMergeId = shuffleMergePartitionsMap.keySet().stream()
+            .mapToInt(v -> v).max().orElse(UNDEFINED_SHUFFLE_MERGE_ID);
+          if (latestShuffleMergeId > shuffleMergeId) {
+            logger.info("Rejecting shuffle blocks push request for shuffle {} with"
+                + " shuffleMergeId {} for application {}_{} as a higher shuffleMergeId"
+                + " {} request is already seen", shuffleId, shuffleMergeId,
+                appShuffleInfo.appId, appShuffleInfo.attemptId, latestShuffleMergeId));
+            // Reject the request as we have already seen a higher shuffleMergeId than the
+            // current incoming one
             return null;
+          } else {
+            // Higher shuffleMergeId seen for the shuffle ID meaning new stage attempt is being
+            // run for the shuffle ID. Close and clean up old shuffleMergeId files,
+            // happens in the non-deterministic stage retries
+            logger.info("Creating a new attempt for shuffle blocks push request for"
+               + " shuffle {} with shuffleMergeId {} for application {}_{} since it is"
+               + " higher than the latest shuffleMergeId {} already seen", shuffleId,
+               shuffleMergeId, appShuffleInfo.appId, appShuffleInfo.attemptId,
+               latestShuffleMergeId));
+            if (null != shuffleMergePartitionsMap.get(latestShuffleMergeId)) {
+              Map<Integer, AppShufflePartitionInfo> shufflePartitions =
+                shuffleMergePartitionsMap.get(latestShuffleMergeId);
+              mergedShuffleCleaner.execute(() ->
+                closeAndDeletePartitionFiles(shufflePartitions));
+            }
+            shuffleMergePartitionsMap.put(latestShuffleMergeId, STALE_SHUFFLE_PARTITIONS);
+            Map<Integer, AppShufflePartitionInfo> newPartitionsMap = new ConcurrentHashMap<>();
+            shuffleMergePartitionsMap.put(shuffleMergeId, newPartitionsMap);
+            return shuffleMergePartitionsMap;
           }
-          return new ConcurrentHashMap<>();
-        } else {
-          return map;
         }
       });
-    if (shufflePartitions == null) {
+
+    Map<Integer, AppShufflePartitionInfo> shufflePartitions = shuffleMergePartitions.get(shuffleMergeId);

Review comment:
       This can race with the update above - pass it from the `compute` above 
via an `AtomicReference` ?
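
   A sketch of that shape, with the existing branch bodies elided as comments:

   ```
    AtomicReference<Map<Integer, AppShufflePartitionInfo>> shufflePartitionsRef =
      new AtomicReference<>();
    partitions.compute(shuffleId, (id, shuffleMergePartitionsMap) -> {
      // ... existing create/reject/clean-up branches ...
      // Capture the per-mergeId map while still inside compute(), so a concurrent
      // writer cannot swap it out between compute() returning and the read below.
      shufflePartitionsRef.set(shuffleMergePartitionsMap.get(shuffleMergeId));
      return shuffleMergePartitionsMap;
    });
    Map<Integer, AppShufflePartitionInfo> shufflePartitions = shufflePartitionsRef.get();
   ```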

##########
File path: 
common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RemoteBlockPushResolver.java
##########
@@ -135,51 +150,87 @@ protected AppShuffleInfo validateAndGetAppShuffleInfo(String appId) {
   private AppShufflePartitionInfo getOrCreateAppShufflePartitionInfo(
       AppShuffleInfo appShuffleInfo,
       int shuffleId,
+      int shuffleMergeId,
       int reduceId) {
-    File dataFile = appShuffleInfo.getMergedShuffleDataFile(shuffleId, reduceId);
-    ConcurrentMap<Integer, Map<Integer, AppShufflePartitionInfo>> partitions =
+    ConcurrentMap<Integer, Map<Integer, Map<Integer, AppShufflePartitionInfo>>> partitions =
       appShuffleInfo.partitions;
-    Map<Integer, AppShufflePartitionInfo> shufflePartitions =
-      partitions.compute(shuffleId, (id, map) -> {
-        if (map == null) {
-          // If this partition is already finalized then the partitions map will not contain the
-          // shuffleId but the data file would exist. In that case the block is considered late.
-          if (dataFile.exists()) {
+    Map<Integer, Map<Integer, AppShufflePartitionInfo>> shuffleMergePartitions =
+      partitions.compute(shuffleId, (id, shuffleMergePartitionsMap) -> {
+        if (shuffleMergePartitionsMap == null) {
+          logger.info("Creating a new attempt for shuffle blocks push request for"
+              + " shuffle {} with shuffleMergeId {} for application {}_{}", shuffleId,
+              shuffleMergeId, appShuffleInfo.appId, appShuffleInfo.attemptId));
+          Map<Integer, Map<Integer, AppShufflePartitionInfo>> newShuffleMergePartitions
+            = new ConcurrentHashMap<>();
+          Map<Integer, AppShufflePartitionInfo> newPartitionsMap = new ConcurrentHashMap<>();
+          newShuffleMergePartitions.put(shuffleMergeId, newPartitionsMap);
+          return newShuffleMergePartitions;
+        } else if (shuffleMergePartitionsMap.containsKey(shuffleMergeId)) {
+          return shuffleMergePartitionsMap;
+        } else {
+          int latestShuffleMergeId = shuffleMergePartitionsMap.keySet().stream()
+            .mapToInt(v -> v).max().orElse(UNDEFINED_SHUFFLE_MERGE_ID);
+          if (latestShuffleMergeId > shuffleMergeId) {
+            logger.info("Rejecting shuffle blocks push request for shuffle {} with"
+                + " shuffleMergeId {} for application {}_{} as a higher shuffleMergeId"
+                + " {} request is already seen", shuffleId, shuffleMergeId,
+                appShuffleInfo.appId, appShuffleInfo.attemptId, latestShuffleMergeId));
+            // Reject the request as we have already seen a higher shuffleMergeId than the
+            // current incoming one
             return null;
+          } else {
+            // Higher shuffleMergeId seen for the shuffle ID meaning new stage attempt is being
+            // run for the shuffle ID. Close and clean up old shuffleMergeId files,
+            // happens in the non-deterministic stage retries
+            logger.info("Creating a new attempt for shuffle blocks push request for"
+               + " shuffle {} with shuffleMergeId {} for application {}_{} since it is"
+               + " higher than the latest shuffleMergeId {} already seen", shuffleId,
+               shuffleMergeId, appShuffleInfo.appId, appShuffleInfo.attemptId,
+               latestShuffleMergeId));
+            if (null != shuffleMergePartitionsMap.get(latestShuffleMergeId)) {

Review comment:
       Why would this be `null` ?

##########
File path: 
common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RemoteBlockPushResolver.java
##########
@@ -78,6 +80,19 @@
   public static final String MERGE_DIR_KEY = "mergeDir";
   public static final String ATTEMPT_ID_KEY = "attemptId";
   private static final int UNDEFINED_ATTEMPT_ID = -1;
+  private static final int UNDEFINED_SHUFFLE_MERGE_ID = Integer.MIN_VALUE;
+
+  // ConcurrentHashMap doesn't allow null for keys or values, which is why this is required.
+  // Marker to identify stale shuffle partitions, which typically happen in the case of
+  // indeterminate stage retries.
+  @VisibleForTesting
+  public static final Map<Integer, AppShufflePartitionInfo> STALE_SHUFFLE_PARTITIONS =
+    new ConcurrentHashMap<>();
+
+  // Marker for finalized shuffle partitions, used to identify late blocks getting merged.
+  @VisibleForTesting
+  public static final Map<Integer, AppShufflePartitionInfo> FINALIZED_SHUFFLE_PARTITIONS =
+    new ConcurrentHashMap<>();

Review comment:
       Make this also `Collections.emptyMap` ?

##########
File path: 
common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RemoteBlockPushResolver.java
##########
@@ -135,51 +150,87 @@ protected AppShuffleInfo validateAndGetAppShuffleInfo(String appId) {
   private AppShufflePartitionInfo getOrCreateAppShufflePartitionInfo(
       AppShuffleInfo appShuffleInfo,
       int shuffleId,
+      int shuffleMergeId,
       int reduceId) {
-    File dataFile = appShuffleInfo.getMergedShuffleDataFile(shuffleId, reduceId);
-    ConcurrentMap<Integer, Map<Integer, AppShufflePartitionInfo>> partitions =
+    ConcurrentMap<Integer, Map<Integer, Map<Integer, AppShufflePartitionInfo>>> partitions =
       appShuffleInfo.partitions;
-    Map<Integer, AppShufflePartitionInfo> shufflePartitions =
-      partitions.compute(shuffleId, (id, map) -> {
-        if (map == null) {
-          // If this partition is already finalized then the partitions map will not contain the
-          // shuffleId but the data file would exist. In that case the block is considered late.
-          if (dataFile.exists()) {
+    Map<Integer, Map<Integer, AppShufflePartitionInfo>> shuffleMergePartitions =
+      partitions.compute(shuffleId, (id, shuffleMergePartitionsMap) -> {
+        if (shuffleMergePartitionsMap == null) {
+          logger.info("Creating a new attempt for shuffle blocks push request for"
+              + " shuffle {} with shuffleMergeId {} for application {}_{}", shuffleId,
+              shuffleMergeId, appShuffleInfo.appId, appShuffleInfo.attemptId));

Review comment:
       This line would not compile, can you please rerun build/tests locally ?

##########
File path: 
common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RemoteBlockPushResolver.java
##########
@@ -649,13 +759,16 @@ public void onData(String streamId, ByteBuffer buf) throws IOException {
      // memory, while still providing the necessary guarantee.
      synchronized (partitionInfo) {
        Map<Integer, AppShufflePartitionInfo> shufflePartitions =
-          appShuffleInfo.partitions.get(partitionInfo.shuffleId);
-        // If the partitionInfo corresponding to (appId, shuffleId, reduceId) is no longer present
-        // then it means that the shuffle merge has already been finalized. We should thus ignore
-        // the data and just drain the remaining bytes of this message. This check should be
-        // placed inside the synchronized block to make sure that checking the key is still
-        // present and processing the data is atomic.
-        if (shufflePartitions == null || !shufflePartitions.containsKey(partitionInfo.reduceId)) {
+          appShuffleInfo.partitions.get(partitionInfo.shuffleId)
+            .get(partitionInfo.shuffleMergeId);
+        // If the partitionInfo corresponding to (appId, shuffleId, shuffleMergeId, reduceId)
+        // is either set to STALE_SHUFFLE_PARTITIONS or FINALIZED_SHUFFLE_PARTITIONS then it
+        // means that the stream request is for an older shuffleMergeId or the shuffle is
+        // finalized. We should thus ignore the data and just drain the remaining bytes of this
+        // message. This check should be placed inside the synchronized block to make sure that
+        // checking the key is still present and processing the data is atomic.
+        if (shufflePartitions == STALE_SHUFFLE_PARTITIONS ||
+            shufflePartitions == FINALIZED_SHUFFLE_PARTITIONS) {

Review comment:
       nit: Create a `isStaleOrTooLate(shufflePartitions)` util method
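
   e.g., a sketch:

   ```
  private static boolean isStaleOrTooLate(Map<Integer, AppShufflePartitionInfo> shufflePartitions) {
    // Reference comparison against the two marker maps.
    return shufflePartitions == STALE_SHUFFLE_PARTITIONS ||
      shufflePartitions == FINALIZED_SHUFFLE_PARTITIONS;
  }
   ```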

##########
File path: 
common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RemoteBlockPushResolver.java
##########
@@ -408,14 +507,24 @@ public MergeStatuses finalizeShuffleMerge(FinalizeShuffleMerge msg) throws IOExc
      throw new IllegalArgumentException(
        String.format("The attempt id %s in this FinalizeShuffleMerge message does not match "
          + "with the current attempt id %s stored in shuffle service for application %s",
-          msg.appAttemptId, appShuffleInfo.attemptId, msg.appId));
+            msg.appAttemptId, appShuffleInfo.attemptId, msg.appId));
+    }
+    Map<Integer, Map<Integer, AppShufflePartitionInfo>> shuffleMergePartitions =
+        appShuffleInfo.partitions.get(msg.shuffleId);
+    Map<Integer, AppShufflePartitionInfo> shufflePartitions = shuffleMergePartitions.get(msg.shuffleMergeId);
+    if (shufflePartitions == STALE_SHUFFLE_PARTITIONS) {
+      throw new RuntimeException(String.format("Shuffle merge finalize request for shuffle %s"
+        + " with shuffleMergeId %s is %s", msg.shuffleId, msg.shuffleMergeId,
+          ErrorHandler.BlockPushErrorHandler.STALE_SHUFFLE_FINALIZE));
+    } else {
+      shuffleMergePartitions.put(msg.shuffleMergeId, FINALIZED_SHUFFLE_PARTITIONS);
+      appShuffleInfo.partitions.put(msg.shuffleId, shuffleMergePartitions);

Review comment:
       Why `appShuffleInfo.partitions.put` ?

##########
File path: 
common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/FinalizeShuffleMerge.java
##########
@@ -68,28 +72,31 @@ public boolean equals(Object other) {
     if (other != null && other instanceof FinalizeShuffleMerge) {
       FinalizeShuffleMerge o = (FinalizeShuffleMerge) other;
       return Objects.equal(appId, o.appId)
-        && appAttemptId == o.appAttemptId
-        && shuffleId == o.shuffleId;
+        && appAttemptId == appAttemptId

Review comment:
       This was fixed in master - please update to latest.

##########
File path: 
common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RemoteBlockPushResolver.java
##########
@@ -408,14 +507,24 @@ public MergeStatuses finalizeShuffleMerge(FinalizeShuffleMerge msg) throws IOExc
      throw new IllegalArgumentException(
        String.format("The attempt id %s in this FinalizeShuffleMerge message does not match "
          + "with the current attempt id %s stored in shuffle service for application %s",
-          msg.appAttemptId, appShuffleInfo.attemptId, msg.appId));
+            msg.appAttemptId, appShuffleInfo.attemptId, msg.appId));
+    }
+    Map<Integer, Map<Integer, AppShufflePartitionInfo>> shuffleMergePartitions =
+        appShuffleInfo.partitions.get(msg.shuffleId);
+    Map<Integer, AppShufflePartitionInfo> shufflePartitions = shuffleMergePartitions.get(msg.shuffleMergeId);
+    if (shufflePartitions == STALE_SHUFFLE_PARTITIONS) {
+      throw new RuntimeException(String.format("Shuffle merge finalize request for shuffle %s"
+        + " with shuffleMergeId %s is %s", msg.shuffleId, msg.shuffleMergeId,
+          ErrorHandler.BlockPushErrorHandler.STALE_SHUFFLE_FINALIZE));
+    } else {
+      shuffleMergePartitions.put(msg.shuffleMergeId, FINALIZED_SHUFFLE_PARTITIONS);
+      appShuffleInfo.partitions.put(msg.shuffleId, shuffleMergePartitions);
    }
-    Map<Integer, AppShufflePartitionInfo> shufflePartitions =
-      appShuffleInfo.partitions.remove(msg.shuffleId);

Review comment:
       What is the rationale behind not removing old shuffle ids (if this is the latest shuffle) ?
   How long are we planning to keep them around ? This is effectively a memory leak.
   
   Having said that, if you do want to do this, rewrite the code as:
   
   ```
   AtomicReference<Map<Integer, AppShufflePartitionInfo>> shufflePartitions = new AtomicReference<>(null);
   shuffleMergePartitions.compute(msg.shuffleMergeId, (id, map) -> {
     if (null == map) { ... }
     else if (FINALIZED_SHUFFLE_PARTITIONS == map) { ... }
     else {
       if (STALE_SHUFFLE_PARTITIONS == map) { ... }
       shufflePartitions.set(map);
       return FINALIZED_SHUFFLE_PARTITIONS;
     }
   });
   ```
   

##########
File path: 
common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RemoteBlockPushResolver.java
##########
@@ -725,15 +838,26 @@ public void onData(String streamId, ByteBuffer buf) throws IOException {
    @Override
    public void onComplete(String streamId) throws IOException {
      synchronized (partitionInfo) {
-        logger.trace("{} shuffleId {} reduceId {} onComplete invoked",
-          partitionInfo.appId, partitionInfo.shuffleId,
+        logger.trace("{} shuffleId {} shuffleMergeId {} reduceId {} onComplete invoked",
+          partitionInfo.appId, partitionInfo.shuffleId, partitionInfo.shuffleMergeId,
          partitionInfo.reduceId);
        Map<Integer, AppShufflePartitionInfo> shufflePartitions =
-          appShuffleInfo.partitions.get(partitionInfo.shuffleId);
+          appShuffleInfo.partitions.get(partitionInfo.shuffleId)
+              .get(partitionInfo.shuffleMergeId);

Review comment:
       Handle `null`'s
   Essentially, any get on `appShuffleInfo.partitions` can return `null`. 
Similarly, a `get` on the returned map can also be `null`  - please handle it.
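
   For instance, a sketch of a guarded lookup (what to do in the `null` case is left to the surrounding logic):

   ```
        Map<Integer, Map<Integer, AppShufflePartitionInfo>> shuffleMergePartitions =
          appShuffleInfo.partitions.get(partitionInfo.shuffleId);
        // Either level of the nested map can be absent, so guard each get().
        Map<Integer, AppShufflePartitionInfo> shufflePartitions =
          shuffleMergePartitions != null
            ? shuffleMergePartitions.get(partitionInfo.shuffleMergeId)
            : null;
   ```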
   

##########
File path: 
common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/FetchShuffleBlockChunks.java
##########
@@ -68,13 +72,13 @@ public boolean equals(Object o) {
 
     FetchShuffleBlockChunks that = (FetchShuffleBlockChunks) o;
     if (!super.equals(that)) return false;
-    if (!Arrays.equals(reduceIds, that.reduceIds)) return false;
+    if (shuffleMergeId != that.shuffleMergeId || !Arrays.equals(reduceIds, that.reduceIds)) return false;
     return Arrays.deepEquals(chunkIds, that.chunkIds);
   }
 
   @Override
   public int hashCode() {
-    int result = super.hashCode();
+    int result = super.hashCode() + shuffleMergeId;

Review comment:
       nit: `super.hashCode() * 31 + shuffleMergeId`

##########
File path: core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockPusher.scala
##########
@@ -69,16 +69,19 @@ private[spark] class ShuffleBlockPusher(conf: SparkConf) extends Logging {
    new BlockPushErrorHandler() {
      // For a connection exception against a particular host, we will stop pushing any
      // blocks to just that host and continue push blocks to other hosts. So, here push of
-      // all blocks will only stop when it is "Too Late". Also see updateStateAndCheckIfPushMore.
+      // all blocks will only stop when it is "Too Late" or "Invalid Block push".
+      // Also see updateStateAndCheckIfPushMore.
      override def shouldRetryError(t: Throwable): Boolean = {
        // If it is a FileNotFoundException originating from the client while pushing the shuffle
        // blocks to the server, then we stop pushing all the blocks because this indicates the
        // shuffle files are deleted and subsequent block push will also fail.
        if (t.getCause != null && t.getCause.isInstanceOf[FileNotFoundException]) {
          return false
        }
-        // If the block is too late, there is no need to retry it
-        !Throwables.getStackTraceAsString(t).contains(BlockPushErrorHandler.TOO_LATE_MESSAGE_SUFFIX)
+        // If the block is too late or an invalid block push, there is no need to retry it
+        !(Throwables.getStackTraceAsString(t)
+          .contains(BlockPushErrorHandler.TOO_LATE_MESSAGE_SUFFIX) ||
+          Throwables.getStackTraceAsString(t).contains(BlockPushErrorHandler.STALE_BLOCK_PUSH));

Review comment:
       Add local variable for `Throwables.getStackTraceAsString` and use that 
to avoid repeated recomputation.

##########
File path: 
common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/MergeStatuses.java
##########
@@ -89,6 +95,7 @@ public boolean equals(Object other) {
     if (other != null && other instanceof MergeStatuses) {
       MergeStatuses o = (MergeStatuses) other;
       return Objects.equal(shuffleId, o.shuffleId)
+        &&  Objects.equal(shuffleMergeId, o.shuffleMergeId)

Review comment:
       super nit: remove additional whitespace after `&&`

##########
File path: core/src/main/scala/org/apache/spark/Dependency.scala
##########
@@ -122,6 +120,14 @@ class ShuffleDependency[K: ClassTag, V: ClassTag, C: ClassTag](
   */
  private[this] var _shuffleMergedFinalized: Boolean = false
 
+  /**
+   * shuffleMergeId is used to give temporal ordering to the executions of a ShuffleDependency.
+   * This is required in order to handle indeterminate stage retries for push-based shuffle.
+   */
+  private[this] var _shuffleMergeId: Int = -1

Review comment:
       start with `0`

##########
File path: core/src/main/scala/org/apache/spark/Dependency.scala
##########
@@ -148,6 +154,22 @@ class ShuffleDependency[K: ClassTag, V: ClassTag, C: ClassTag](
     }
   }
 
+  def newShuffleMergeState(): Unit = {
+    _shuffleMergeEnabled = canShuffleMergeBeEnabled()
+    _shuffleMergedFinalized = false
+    mergerLocs = Nil
+    _shuffleMergeId = _shuffleMergeId + 1

Review comment:
       super nit: `_shuffleMergeId += 1`

##########
File path: core/src/main/scala/org/apache/spark/Dependency.scala
##########
@@ -148,6 +153,18 @@ class ShuffleDependency[K: ClassTag, V: ClassTag, C: ClassTag](
     }
   }
 
+  def resetShuffleMergeState(): Unit = {
+    _shuffleMergeEnabled = canShuffleMergeBeEnabled()
+    _shuffleMergedFinalized = false
+    mergerLocs = Nil

Review comment:
       +CC @Victsm who worked on colocating mergers.

##########
File path: core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockPusher.scala
##########
@@ -361,7 +366,8 @@ private[spark] class ShuffleBlockPusher(conf: SparkConf) extends Logging {
     for (reduceId <- 0 until numPartitions) {
       val blockSize = partitionLengths(reduceId)
       logDebug(
-        s"Block ${ShufflePushBlockId(shuffleId, partitionId, reduceId)} is of size $blockSize")
+        s"Block ${ShufflePushBlockId(shuffleId, partitionId, shuffleMergeId,

Review comment:
       change the parameter order: shuffleId, shuffleMergeId, partitionId.
   
   Please check the entire PR very carefully to avoid these kinds of errors; I might have missed some.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]



---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
