mridulm commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r645277961



##########
File path: 
common/network-common/src/main/java/org/apache/spark/network/protocol/MergedBlockMetaRequest.java
##########
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+import com.google.common.base.Objects;
+import io.netty.buffer.ByteBuf;
+
+/**
+ * Request to find the meta information for the specified merged block. The meta information
+ * contains the number of chunks in the merged blocks and the maps ids in each chunk.
+ *
+ * @since 3.2.0
+ */
+public class MergedBlockMetaRequest extends AbstractMessage implements RequestMessage {
+  public final long requestId;
+  public final String appId;
+  public final int shuffleId;
+  public final int reduceId;
+
+  public MergedBlockMetaRequest(long requestId, String appId, int shuffleId, int reduceId) {
+    super(null, false);
+    this.requestId = requestId;
+    this.appId = appId;
+    this.shuffleId = shuffleId;
+    this.reduceId = reduceId;
+  }
+
+  @Override
+  public Type type() {
+    return Type.MergedBlockMetaRequest;
+  }
+
+  @Override
+  public int encodedLength() {
+    return 8 + Encoders.Strings.encodedLength(appId) + 8;
+  }
+
+  @Override
+  public void encode(ByteBuf buf) {
+    buf.writeLong(requestId);
+    Encoders.Strings.encode(buf, appId);
+    buf.writeInt(shuffleId);
+    buf.writeInt(reduceId);
+  }
+
+  public static MergedBlockMetaRequest decode(ByteBuf buf) {
+    long requestId = buf.readLong();
+    String appId = Encoders.Strings.decode(buf);
+    int shuffleId = buf.readInt();
+    int reduceId = buf.readInt();
+    return new MergedBlockMetaRequest(requestId, appId, shuffleId, reduceId);
+  }
+
+  @Override
+  public int hashCode() {
+    return Objects.hashCode(requestId, appId, shuffleId, reduceId);
+  }
+
+  @Override
+  public boolean equals(Object other) {
+    if (other instanceof MergedBlockMetaRequest) {
+      MergedBlockMetaRequest o = (MergedBlockMetaRequest) other;
+      return requestId == o.requestId && Objects.equal(appId, o.appId)
+        && shuffleId == o.shuffleId && reduceId == o.reduceId;

Review comment:
       nit: move the `appId` check to the end, so the cheaper primitive comparisons are evaluated first.
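Here is a minimal sketch of the suggested ordering. `MetaRequestKey` is a hypothetical stand-in for `MergedBlockMetaRequest`, and it uses `java.util.Objects` instead of Guava's `Objects`. Because `&&` short-circuits, the `String` comparison runs only when all the primitive fields already match.

```java
import java.util.Objects;

// Hypothetical stand-in for MergedBlockMetaRequest; only equality is shown.
final class MetaRequestKey {
  final long requestId;
  final String appId;
  final int shuffleId;
  final int reduceId;

  MetaRequestKey(long requestId, String appId, int shuffleId, int reduceId) {
    this.requestId = requestId;
    this.appId = appId;
    this.shuffleId = shuffleId;
    this.reduceId = reduceId;
  }

  @Override
  public boolean equals(Object other) {
    if (!(other instanceof MetaRequestKey)) {
      return false;
    }
    MetaRequestKey o = (MetaRequestKey) other;
    // Cheap primitive comparisons first; the String comparison runs only if they all match.
    return requestId == o.requestId
      && shuffleId == o.shuffleId
      && reduceId == o.reduceId
      && Objects.equals(appId, o.appId);
  }

  @Override
  public int hashCode() {
    return Objects.hash(requestId, appId, shuffleId, reduceId);
  }
}
```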

##########
File path: 
common/network-common/src/main/java/org/apache/spark/network/client/BaseResponseCallback.java
##########
@@ -0,0 +1,31 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.client;
+
+/**
+ * A basic callback. This is extended by {@link RpcResponseCallback} and
+ * {@link MergedBlockMetaResponseCallback} so that both RpcRequests and MergedBlockMetaRequests
+ * can be handled in {@link TransportResponseHandler} a similar way.
+ *
+ * @since 3.2.0
+ */
+public interface BaseResponseCallback {

Review comment:
       nit: I don't have a good suggestion myself, but can we find a better name for this interface?
   Thoughts, @Ngone51, @attilapiros?
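To show what the interface is for, here is a simplified sketch of the pattern the Javadoc describes. Both callback types share a base type, so one map of outstanding requests can hold either kind. A failure can then be dispatched through the base type's `onFailure`, which is what the `RpcFailure` branch of `TransportResponseHandler` does. All names and signatures below are hypothetical simplifications, not the PR's actual interfaces.

```java
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

// Hypothetical, simplified callback hierarchy; the real interfaces may differ.
interface BaseCallback {
  void onFailure(Throwable e);
}

interface RpcCallback extends BaseCallback {
  void onSuccess(String response);
}

interface MetaCallback extends BaseCallback {
  void onSuccess(int numChunks, String body);
}

final class ResponseHandlerSketch {
  // One map holds callbacks for both request kinds because they share a supertype.
  final Map<Long, BaseCallback> outstanding = new ConcurrentHashMap<>();

  // Failure handling does not need to know which kind of request failed.
  boolean fail(long requestId, String error) {
    BaseCallback listener = outstanding.remove(requestId);
    if (listener == null) {
      return false; // not outstanding; the real handler logs a warning and ignores it
    }
    listener.onFailure(new RuntimeException(error));
    return true;
  }
}
```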

##########
File path: 
common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java
##########
@@ -413,6 +466,47 @@ public ManagedBuffer next() {
     }
   }
 
+  private class ShuffleChunkManagedBufferIterator implements Iterator<ManagedBuffer> {
+
+    private int reduceIdx = 0;
+    private int chunkIdx = 0;
+
+    private final String appId;
+    private final int shuffleId;
+    private final int[] reduceIds;
+    private final int[][] chunkIds;
+
+    ShuffleChunkManagedBufferIterator(FetchShuffleBlockChunks msg) {
+      appId = msg.appId;
+      shuffleId = msg.shuffleId;
+      reduceIds = msg.reduceIds;
+      chunkIds = msg.chunkIds;
+    }
+
+    @Override
+    public boolean hasNext() {
+      // reduceIds.length must equal to chunkIds.length, and the passed in FetchShuffleBlockChunks
+      // must have non-empty reduceIds and chunkIds, see the checking logic in
+      // OneForOneBlockFetcher.
+      assert(reduceIds.length != 0 && reduceIds.length == chunkIds.length);
+      return reduceIdx < reduceIds.length && chunkIdx < chunkIds[reduceIdx].length;
+    }
+
+    @Override
+    public ManagedBuffer next() {

Review comment:
       Reviewer note: the `Iterator` contract requires `next` to check `hasNext` and throw `NoSuchElementException` when there are no more elements.
   Unfortunately, the other iterators in `ExternalBlockHandler` don't do this either.
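Here is a sketch of the contract on a simplified iterator. `ChunkIdIterator` is hypothetical and yields `"reduceId:chunkId"` strings instead of `ManagedBuffer`s, but it walks `reduceIds`/`chunkIds` the same way as `ShuffleChunkManagedBufferIterator`. Without the guard, a call to `next()` past the end would fail with an `ArrayIndexOutOfBoundsException`.

```java
import java.util.Iterator;
import java.util.NoSuchElementException;

// Hypothetical iterator over (reduceId, chunkId) pairs, shaped like
// ShuffleChunkManagedBufferIterator but returning strings.
final class ChunkIdIterator implements Iterator<String> {
  private final int[] reduceIds;
  private final int[][] chunkIds;
  private int reduceIdx = 0;
  private int chunkIdx = 0;

  ChunkIdIterator(int[] reduceIds, int[][] chunkIds) {
    this.reduceIds = reduceIds;
    this.chunkIds = chunkIds;
  }

  @Override
  public boolean hasNext() {
    return reduceIdx < reduceIds.length && chunkIdx < chunkIds[reduceIdx].length;
  }

  @Override
  public String next() {
    // Iterator contract: signal exhaustion with NoSuchElementException
    // instead of an ArrayIndexOutOfBoundsException.
    if (!hasNext()) {
      throw new NoSuchElementException();
    }
    String result = reduceIds[reduceIdx] + ":" + chunkIds[reduceIdx][chunkIdx];
    if (chunkIdx < chunkIds[reduceIdx].length - 1) {
      chunkIdx += 1;
    } else {
      chunkIdx = 0;
      reduceIdx += 1;
    }
    return result;
  }
}
```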

##########
File path: 
common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
##########
@@ -88,82 +93,124 @@ public OneForOneBlockFetcher(
     if (blockIds.length == 0) {
       throw new IllegalArgumentException("Zero-sized blockIds array");
     }
-    if (!transportConf.useOldFetchProtocol() && isShuffleBlocks(blockIds)) {
+    if (!transportConf.useOldFetchProtocol() && areShuffleBlocksOrChunks(blockIds)) {
       this.blockIds = new String[blockIds.length];
-      this.message = createFetchShuffleBlocksMsgAndBuildBlockIds(appId, execId, blockIds);
+      this.message = createFetchShuffleBlocksOrChunksMsg(appId, execId, blockIds);
     } else {
       this.blockIds = blockIds;
       this.message = new OpenBlocks(appId, execId, blockIds);
     }
   }
 
-  private boolean isShuffleBlocks(String[] blockIds) {
+  /**
+   * Check if the array of block IDs are all shuffle block IDs. With push based shuffle,
+   * the shuffle block ID could be either unmerged shuffle block IDs or merged shuffle chunk
+   * IDs. For a given stream of shuffle blocks to be fetched in one request, they would be either
+   * all unmerged shuffle blocks or all merged shuffle chunks.
+   * @param blockIds block ID array
+   * @return whether the array contains only shuffle block IDs
+   */
+  private boolean areShuffleBlocksOrChunks(String[] blockIds) {
     for (String blockId : blockIds) {
-      if (!blockId.startsWith("shuffle_")) {
+      if (!blockId.startsWith(SHUFFLE_BLOCK_PREFIX) &&
+          !blockId.startsWith(SHUFFLE_CHUNK_PREFIX)) {
         return false;
       }
     }
     return true;
   }
 
+  /** Creates either a {@link FetchShuffleBlocks} or {@link FetchShuffleBlockChunks} message. */
+  private AbstractFetchShuffleBlocks createFetchShuffleBlocksOrChunksMsg(
+      String appId,
+      String execId,
+      String[] blockIds) {
+    if (blockIds[0].startsWith(SHUFFLE_CHUNK_PREFIX)) {
+      return createFetchShuffleMsgAndBuildBlockIds(appId, execId, blockIds, true);
+    } else {
+      return createFetchShuffleMsgAndBuildBlockIds(appId, execId, blockIds, false);
+    }
+  }
+
   /**
-   * Create FetchShuffleBlocks message and rebuild internal blockIds by
+   * Create FetchShuffleBlocks/FetchShuffleBlockChunks message and rebuild internal blockIds by
    * analyzing the pass in blockIds.
    */
-  private FetchShuffleBlocks createFetchShuffleBlocksMsgAndBuildBlockIds(
-      String appId, String execId, String[] blockIds) {
+  private AbstractFetchShuffleBlocks createFetchShuffleMsgAndBuildBlockIds(
+      String appId,
+      String execId,
+      String[] blockIds,
+      boolean areMergedChunks) {
     String[] firstBlock = splitBlockId(blockIds[0]);
     int shuffleId = Integer.parseInt(firstBlock[1]);
     boolean batchFetchEnabled = firstBlock.length == 5;
 
-    LinkedHashMap<Long, BlocksInfo> mapIdToBlocksInfo = new LinkedHashMap<>();
+    // In case of FetchShuffleBlocks, primaryId is mapId. For FetchShuffleBlockChunks, primaryId
+    // is reduceId.
+    LinkedHashMap<Number, BlocksInfo> primaryIdToBlocksInfo = new LinkedHashMap<>();
     for (String blockId : blockIds) {
       String[] blockIdParts = splitBlockId(blockId);
       if (Integer.parseInt(blockIdParts[1]) != shuffleId) {
         throw new IllegalArgumentException("Expected shuffleId=" + shuffleId +
           ", got:" + blockId);
       }
-      long mapId = Long.parseLong(blockIdParts[2]);
-      if (!mapIdToBlocksInfo.containsKey(mapId)) {
-        mapIdToBlocksInfo.put(mapId, new BlocksInfo());
+      Number primaryId;
+      if (!areMergedChunks) {
+        primaryId = Long.parseLong(blockIdParts[2]);
+      } else {
+        primaryId = Integer.parseInt(blockIdParts[2]);
+      }
+      if (!primaryIdToBlocksInfo.containsKey(primaryId)) {
+        primaryIdToBlocksInfo.put(primaryId, new BlocksInfo());
       }
-      BlocksInfo blocksInfoByMapId = mapIdToBlocksInfo.get(mapId);
-      blocksInfoByMapId.blockIds.add(blockId);
-      blocksInfoByMapId.reduceIds.add(Integer.parseInt(blockIdParts[3]));
+      BlocksInfo blocksInfoByPrimaryId = primaryIdToBlocksInfo.get(primaryId);
+      blocksInfoByPrimaryId.blockIds.add(blockId);
+      blocksInfoByPrimaryId.ids.add(Integer.parseInt(blockIdParts[3]));

Review comment:
       Add a one-line note on what `blockIdParts[3]` can be.
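For reference, the comments elsewhere in this PR imply two ID layouts. Unmerged blocks use `shuffle_<shuffleId>_<mapId>_<reduceId>`, with a fifth part (the end reduce ID) when batch fetch is enabled. Merged chunks use `<chunkPrefix>_<shuffleId>_<reduceId>_<chunkId>`. The sketch below assumes the chunk prefix is `shuffleChunk`, and the helper name is hypothetical. A note such as "reduceId for unmerged blocks (startReduceId with batch fetch), chunkId for merged chunks" would cover it.

```java
// Hypothetical helper; the "shuffleChunk" prefix is an assumption for this sketch.
final class BlockIdParts {
  // Names the field held in blockIdParts[3] for a shuffle block or merged chunk ID.
  static String partThreeName(String blockId) {
    String[] parts = blockId.split("_");
    if (parts.length < 4) {
      throw new IllegalArgumentException("Unexpected block id: " + blockId);
    }
    if (parts[0].equals("shuffleChunk")) {
      return "chunkId"; // shuffleChunk_<shuffleId>_<reduceId>_<chunkId>
    }
    if (parts[0].equals("shuffle")) {
      // shuffle_<shuffleId>_<mapId>_<reduceId>, or with batch fetch
      // shuffle_<shuffleId>_<mapId>_<startReduceId>_<endReduceId>
      return parts.length == 5 ? "startReduceId" : "reduceId";
    }
    throw new IllegalArgumentException("Unexpected block id: " + blockId);
  }
}
```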

##########
File path: core/src/main/scala/org/apache/spark/storage/BlockId.scala
##########
@@ -124,11 +134,12 @@ class UnrecognizedBlockId(name: String)
 @DeveloperApi
 object BlockId {
   val RDD = "rdd_([0-9]+)_([0-9]+)".r
-  val SHUFFLE = "shuffle_([0-9]+)_([0-9]+)_([0-9]+)".r
+  val SHUFFLE = "shuffle_([0-9]+)_(-?[0-9]+)_([0-9]+)".r

Review comment:
       nit: use `\\d+` instead?
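Here is a quick check of the equivalence. It is written in Java to match the other examples, and Scala's `.r` is backed by `java.util.regex.Pattern`, so the behavior is the same. By default, `\d` matches only ASCII `[0-9]`. It widens to Unicode digits only when `UNICODE_CHARACTER_CLASS` is set, so the change is purely cosmetic. `ShuffleIdRegex` is a hypothetical helper.

```java
import java.util.regex.Matcher;
import java.util.regex.Pattern;

// Hypothetical helper using the same pattern as the Scala SHUFFLE regex, written with \d.
final class ShuffleIdRegex {
  static final Pattern SHUFFLE = Pattern.compile("shuffle_(\\d+)_(-?\\d+)_(\\d+)");

  // Returns {shuffleId, mapId, reduceId}, or null if the name does not match.
  static long[] parse(String name) {
    Matcher m = SHUFFLE.matcher(name);
    if (!m.matches()) {
      return null;
    }
    return new long[] {
      Long.parseLong(m.group(1)), Long.parseLong(m.group(2)), Long.parseLong(m.group(3))
    };
  }
}
```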

##########
File path: 
common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java
##########
@@ -128,24 +134,23 @@ protected void handleMessage(
       BlockTransferMessage msgObj,
       TransportClient client,
       RpcResponseCallback callback) {
-    if (msgObj instanceof FetchShuffleBlocks || msgObj instanceof OpenBlocks) {
+    if (msgObj instanceof AbstractFetchShuffleBlocks || msgObj instanceof OpenBlocks) {
       final Timer.Context responseDelayContext = metrics.openBlockRequestLatencyMillis.time();
       try {
         int numBlockIds;
         long streamId;
-        if (msgObj instanceof FetchShuffleBlocks) {
-          FetchShuffleBlocks msg = (FetchShuffleBlocks) msgObj;
+        if (msgObj instanceof AbstractFetchShuffleBlocks) {
+          AbstractFetchShuffleBlocks msg = (AbstractFetchShuffleBlocks) msgObj;
           checkAuth(client, msg.appId);
-          numBlockIds = 0;
-          if (msg.batchFetchEnabled) {
-            numBlockIds = msg.mapIds.length;
+          numBlockIds = ((AbstractFetchShuffleBlocks) msgObj).getNumBlocks();

Review comment:
       `getNumBlocks` makes this code cleaner.
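Here is a sketch of why a polymorphic `getNumBlocks` reads better: each message type counts its own blocks, so the handler no longer branches on `batchFetchEnabled`. The class names are hypothetical simplifications. The batch branch mirrors the removed `numBlockIds = msg.mapIds.length` line. The non-batch branch, which sums the per-map `reduceIds`, is an assumption because the hunk elides it.

```java
// Hypothetical, simplified message classes; the PR's classes carry more fields.
abstract class FetchMsg {
  // Each subclass knows how many blocks it will stream back.
  abstract int getNumBlocks();
}

final class FetchBlocks extends FetchMsg {
  final long[] mapIds;
  final int[][] reduceIds;
  final boolean batchFetchEnabled;

  FetchBlocks(long[] mapIds, int[][] reduceIds, boolean batchFetchEnabled) {
    this.mapIds = mapIds;
    this.reduceIds = reduceIds;
    this.batchFetchEnabled = batchFetchEnabled;
  }

  @Override
  int getNumBlocks() {
    // With batch fetch, each mapId yields one contiguous block.
    if (batchFetchEnabled) {
      return mapIds.length;
    }
    int n = 0;
    for (int[] ids : reduceIds) {
      n += ids.length;
    }
    return n;
  }
}

final class FetchChunks extends FetchMsg {
  final int[][] chunkIds;

  FetchChunks(int[][] chunkIds) {
    this.chunkIds = chunkIds;
  }

  @Override
  int getNumBlocks() {
    // One block per (reduceId, chunkId) pair.
    int n = 0;
    for (int[] ids : chunkIds) {
      n += ids.length;
    }
    return n;
  }
}
```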

##########
File path: 
common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java
##########
@@ -199,14 +200,31 @@ public void handle(ResponseMessage message) throws Exception {
       }
     } else if (message instanceof RpcFailure) {
       RpcFailure resp = (RpcFailure) message;
-      RpcResponseCallback listener = outstandingRpcs.get(resp.requestId);
+      BaseResponseCallback listener = outstandingRpcs.get(resp.requestId);
       if (listener == null) {
         logger.warn("Ignoring response for RPC {} from {} ({}) since it is not outstanding",
           resp.requestId, getRemoteAddress(channel), resp.errorString);
       } else {
         outstandingRpcs.remove(resp.requestId);
         listener.onFailure(new RuntimeException(resp.errorString));
       }
+    } else if (message instanceof MergedBlockMetaSuccess) {
+      MergedBlockMetaSuccess resp = (MergedBlockMetaSuccess) message;
+      MergedBlockMetaResponseCallback listener =
+        (MergedBlockMetaResponseCallback) outstandingRpcs.get(resp.requestId);
+      if (listener == null) {
+        logger.warn(
+          "Ignoring response for MergedBlockMetaRequest {} from {} ({} bytes) since it is not"
+            + " outstanding", resp.requestId, getRemoteAddress(channel), resp.body().size());
+        resp.body().release();
+      } else {
+        outstandingRpcs.remove(resp.requestId);
+        try {
+          listener.onSuccess(resp.getNumChunks(), resp.body());
+        } finally {
+          resp.body().release();

Review comment:
       nit: move `resp.body().release()` into a try/finally that covers this entire `else if` block, so it is called once for both the not-outstanding and the success paths.
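Here is a minimal sketch of that shape. The hypothetical `Body`, `MetaListener`, and `MetaSuccessHandler` types stand in for the response body, `MergedBlockMetaResponseCallback`, and the handler. The single `finally` releases the body exactly once in every case: when the request is not outstanding, when the listener succeeds, and when the listener throws.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical stand-in for a reference-counted response body.
final class Body {
  int refCnt = 1;

  void release() {
    refCnt--;
  }
}

// Hypothetical stand-in for MergedBlockMetaResponseCallback.
interface MetaListener {
  void onSuccess(int numChunks, Body body);
}

final class MetaSuccessHandler {
  final Map<Long, MetaListener> outstanding = new HashMap<>();

  void handle(long requestId, int numChunks, Body body) {
    // One try/finally around the whole branch: release happens exactly once.
    try {
      MetaListener listener = outstanding.remove(requestId);
      if (listener == null) {
        return; // not outstanding; the real handler logs a warning here
      }
      listener.onSuccess(numChunks, body);
    } finally {
      body.release();
    }
  }
}
```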

##########
File path: 
common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java
##########
@@ -189,9 +194,14 @@ protected void handleMessage(
     } else if (msgObj instanceof GetLocalDirsForExecutors) {
       GetLocalDirsForExecutors msg = (GetLocalDirsForExecutors) msgObj;
       checkAuth(client, msg.appId);
-      Map<String, String[]> localDirs = blockManager.getLocalDirs(msg.appId, msg.execIds);
+      String[] execIdsForBlockResolver = Arrays.stream(msg.execIds)
+        .filter(execId -> !SHUFFLE_MERGER_IDENTIFIER.equals(execId)).toArray(String[]::new);
+      Map<String, String[]> localDirs = blockManager.getLocalDirs(msg.appId,
+        execIdsForBlockResolver);
+      if (Arrays.asList(msg.execIds).contains(SHUFFLE_MERGER_IDENTIFIER)) {
+        localDirs.put(SHUFFLE_MERGER_IDENTIFIER, mergeManager.getMergedBlockDirs(msg.appId));
+      }

Review comment:
       ```suggestion
         Set<String> execIdsForBlockResolver = Sets.newHashSet(msg.execIds);
         boolean fetchMergedBlockDirs = execIdsForBlockResolver.remove(SHUFFLE_MERGER_IDENTIFIER);
         Map<String, String[]> localDirs = blockManager.getLocalDirs(msg.appId, execIdsForBlockResolver);
         if (fetchMergedBlockDirs) {
           localDirs.put(SHUFFLE_MERGER_IDENTIFIER, mergeManager.getMergedBlockDirs(msg.appId));
         }
   ```
   
   This needs a corresponding change to `blockManager.getLocalDirs` so that it takes a set of executor IDs instead of an array.

##########
File path: 
common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
##########
@@ -88,82 +93,124 @@ public OneForOneBlockFetcher(
     if (blockIds.length == 0) {
       throw new IllegalArgumentException("Zero-sized blockIds array");
     }
-    if (!transportConf.useOldFetchProtocol() && isShuffleBlocks(blockIds)) {
+    if (!transportConf.useOldFetchProtocol() && areShuffleBlocksOrChunks(blockIds)) {
       this.blockIds = new String[blockIds.length];
-      this.message = createFetchShuffleBlocksMsgAndBuildBlockIds(appId, execId, blockIds);
+      this.message = createFetchShuffleBlocksOrChunksMsg(appId, execId, blockIds);

Review comment:
       Just to clarify: we are not modifying the old fetch protocol at all.

##########
File path: 
common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
##########
@@ -88,82 +93,124 @@ public OneForOneBlockFetcher(
     if (blockIds.length == 0) {
       throw new IllegalArgumentException("Zero-sized blockIds array");
     }
-    if (!transportConf.useOldFetchProtocol() && isShuffleBlocks(blockIds)) {
+    if (!transportConf.useOldFetchProtocol() && areShuffleBlocksOrChunks(blockIds)) {
       this.blockIds = new String[blockIds.length];
-      this.message = createFetchShuffleBlocksMsgAndBuildBlockIds(appId, execId, blockIds);
+      this.message = createFetchShuffleBlocksOrChunksMsg(appId, execId, blockIds);
     } else {
       this.blockIds = blockIds;
       this.message = new OpenBlocks(appId, execId, blockIds);
     }
   }
 
-  private boolean isShuffleBlocks(String[] blockIds) {
+  /**
+   * Check if the array of block IDs are all shuffle block IDs. With push based shuffle,
+   * the shuffle block ID could be either unmerged shuffle block IDs or merged shuffle chunk
+   * IDs. For a given stream of shuffle blocks to be fetched in one request, they would be either
+   * all unmerged shuffle blocks or all merged shuffle chunks.
+   * @param blockIds block ID array
+   * @return whether the array contains only shuffle block IDs
+   */
+  private boolean areShuffleBlocksOrChunks(String[] blockIds) {
     for (String blockId : blockIds) {
-      if (!blockId.startsWith("shuffle_")) {
+      if (!blockId.startsWith(SHUFFLE_BLOCK_PREFIX) &&
+          !blockId.startsWith(SHUFFLE_CHUNK_PREFIX)) {
         return false;
       }
     }
     return true;
   }
 
+  /** Creates either a {@link FetchShuffleBlocks} or {@link FetchShuffleBlockChunks} message. */
+  private AbstractFetchShuffleBlocks createFetchShuffleBlocksOrChunksMsg(
+      String appId,
+      String execId,
+      String[] blockIds) {
+    if (blockIds[0].startsWith(SHUFFLE_CHUNK_PREFIX)) {
+      return createFetchShuffleMsgAndBuildBlockIds(appId, execId, blockIds, true);
+    } else {
+      return createFetchShuffleMsgAndBuildBlockIds(appId, execId, blockIds, false);
+    }
+  }
+
   /**
-   * Create FetchShuffleBlocks message and rebuild internal blockIds by
+   * Create FetchShuffleBlocks/FetchShuffleBlockChunks message and rebuild internal blockIds by
    * analyzing the pass in blockIds.
    */
-  private FetchShuffleBlocks createFetchShuffleBlocksMsgAndBuildBlockIds(
-      String appId, String execId, String[] blockIds) {
+  private AbstractFetchShuffleBlocks createFetchShuffleMsgAndBuildBlockIds(
+      String appId,
+      String execId,
+      String[] blockIds,
+      boolean areMergedChunks) {
     String[] firstBlock = splitBlockId(blockIds[0]);
     int shuffleId = Integer.parseInt(firstBlock[1]);
     boolean batchFetchEnabled = firstBlock.length == 5;
 
-    LinkedHashMap<Long, BlocksInfo> mapIdToBlocksInfo = new LinkedHashMap<>();
+    // In case of FetchShuffleBlocks, primaryId is mapId. For FetchShuffleBlockChunks, primaryId
+    // is reduceId.
+    LinkedHashMap<Number, BlocksInfo> primaryIdToBlocksInfo = new LinkedHashMap<>();
     for (String blockId : blockIds) {
       String[] blockIdParts = splitBlockId(blockId);
       if (Integer.parseInt(blockIdParts[1]) != shuffleId) {
         throw new IllegalArgumentException("Expected shuffleId=" + shuffleId +
           ", got:" + blockId);
       }
-      long mapId = Long.parseLong(blockIdParts[2]);
-      if (!mapIdToBlocksInfo.containsKey(mapId)) {
-        mapIdToBlocksInfo.put(mapId, new BlocksInfo());
+      Number primaryId;
+      if (!areMergedChunks) {
+        primaryId = Long.parseLong(blockIdParts[2]);
+      } else {
+        primaryId = Integer.parseInt(blockIdParts[2]);
+      }
+      if (!primaryIdToBlocksInfo.containsKey(primaryId)) {
+        primaryIdToBlocksInfo.put(primaryId, new BlocksInfo());
       }

Review comment:
       nit:
   
   ```suggestion
         BlocksInfo blocksInfoByPrimaryId = primaryIdToBlocksInfo.computeIfAbsent(primaryId, id -> new BlocksInfo());
   ```
   
   and remove the `get` below.
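Here is a runnable sketch of the suggested `computeIfAbsent` grouping. It uses a hypothetical `BlocksInfoSketch` and `Long` keys, whereas the PR keys the map by `Number`. A single call replaces the `containsKey`/`put`/`get` sequence. Like the PR, the sketch uses a `LinkedHashMap`, so groups keep their first-seen order.

```java
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

// Hypothetical, cut-down BlocksInfo: block IDs and secondary IDs under one primary ID.
final class BlocksInfoSketch {
  final List<String> blockIds = new ArrayList<>();
  final List<Integer> ids = new ArrayList<>();
}

final class GroupByPrimaryId {
  // Groups "<prefix>_<shuffleId>_<primaryId>_<secondaryId>" IDs by primaryId,
  // preserving first-seen order.
  static Map<Long, BlocksInfoSketch> group(String[] blockIds) {
    Map<Long, BlocksInfoSketch> primaryIdToBlocksInfo = new LinkedHashMap<>();
    for (String blockId : blockIds) {
      String[] parts = blockId.split("_");
      long primaryId = Long.parseLong(parts[2]);
      // One lookup instead of containsKey + put + get.
      BlocksInfoSketch info =
        primaryIdToBlocksInfo.computeIfAbsent(primaryId, id -> new BlocksInfoSketch());
      info.blockIds.add(blockId);
      info.ids.add(Integer.parseInt(parts[3]));
    }
    return primaryIdToBlocksInfo;
  }
}
```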

##########
File path: 
common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java
##########
@@ -333,14 +382,18 @@ public ShuffleMetrics() {
       final int[] mapIdAndReduceIds = new int[2 * blockIds.length];
       for (int i = 0; i < blockIds.length; i++) {
         String[] blockIdParts = blockIds[i].split("_");
-        if (blockIdParts.length != 4 || !blockIdParts[0].equals("shuffle")) {
+        if (blockIdParts.length != 4
+          || (!requestForMergedBlockChunks && !blockIdParts[0].equals(SHUFFLE_BLOCK_PREFIX))
+          || (requestForMergedBlockChunks && !blockIdParts[0].equals(SHUFFLE_CHUNK_PREFIX))) {
           throw new IllegalArgumentException("Unexpected shuffle block id format: " + blockIds[i]);
         }
         if (Integer.parseInt(blockIdParts[1]) != shuffleId) {
           throw new IllegalArgumentException("Expected shuffleId=" + shuffleId +
             ", got:" + blockIds[i]);
         }
+        // For regular blocks this is mapId. For chunks this is reduceId.
         mapIdAndReduceIds[2 * i] = Integer.parseInt(blockIdParts[2]);
+        // For regular blocks this is reduceId. For chunks this is chunkId.
         mapIdAndReduceIds[2 * i + 1] = Integer.parseInt(blockIdParts[3]);

Review comment:
       Should we rename this variable (here and in the constructor), and this method, now that they are overloaded to mean either map/reduce IDs or reduce/chunk IDs?

##########
File path: 
common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java
##########
@@ -413,6 +466,47 @@ public ManagedBuffer next() {
     }
   }
 
+  private class ShuffleChunkManagedBufferIterator implements Iterator<ManagedBuffer> {
+
+    private int reduceIdx = 0;
+    private int chunkIdx = 0;
+
+    private final String appId;
+    private final int shuffleId;
+    private final int[] reduceIds;
+    private final int[][] chunkIds;
+
+    ShuffleChunkManagedBufferIterator(FetchShuffleBlockChunks msg) {
+      appId = msg.appId;
+      shuffleId = msg.shuffleId;
+      reduceIds = msg.reduceIds;
+      chunkIds = msg.chunkIds;
+    }
+
+    @Override
+    public boolean hasNext() {
+      // reduceIds.length must equal to chunkIds.length, and the passed in FetchShuffleBlockChunks
+      // must have non-empty reduceIds and chunkIds, see the checking logic in
+      // OneForOneBlockFetcher.
+      assert(reduceIds.length != 0 && reduceIds.length == chunkIds.length);

Review comment:
       Move this assertion into the constructor itself.
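Here is a sketch using a hypothetical `ValidatedChunkCursor`. The shape check runs once, at construction, and `hasNext()` is left as a plain bounds check. Java `assert` statements run only when assertions are enabled (`-ea`). Throwing `IllegalArgumentException`, as this sketch does, is one alternative if the check should also hold in production.

```java
// Hypothetical sketch: validate the message shape once, when the iterator is built.
final class ValidatedChunkCursor {
  private final int[] reduceIds;
  private final int[][] chunkIds;
  private int reduceIdx = 0;
  private int chunkIdx = 0;

  ValidatedChunkCursor(int[] reduceIds, int[][] chunkIds) {
    // Explicit check instead of `assert`, which is skipped unless -ea is set.
    if (reduceIds.length == 0 || reduceIds.length != chunkIds.length) {
      throw new IllegalArgumentException(
        "Expected non-empty reduceIds of the same length as chunkIds");
    }
    this.reduceIds = reduceIds;
    this.chunkIds = chunkIds;
  }

  // No per-call assertion needed; the invariant was checked in the constructor.
  boolean hasNext() {
    return reduceIdx < reduceIds.length && chunkIdx < chunkIds[reduceIdx].length;
  }
}
```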

##########
File path: 
common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
##########
@@ -88,82 +93,124 @@ public OneForOneBlockFetcher(
     if (blockIds.length == 0) {
       throw new IllegalArgumentException("Zero-sized blockIds array");
     }
-    if (!transportConf.useOldFetchProtocol() && isShuffleBlocks(blockIds)) {
+    if (!transportConf.useOldFetchProtocol() && areShuffleBlocksOrChunks(blockIds)) {
       this.blockIds = new String[blockIds.length];
-      this.message = createFetchShuffleBlocksMsgAndBuildBlockIds(appId, execId, blockIds);
+      this.message = createFetchShuffleBlocksOrChunksMsg(appId, execId, blockIds);
     } else {
       this.blockIds = blockIds;
       this.message = new OpenBlocks(appId, execId, blockIds);
     }
   }
 
-  private boolean isShuffleBlocks(String[] blockIds) {
+  /**
+   * Check if the array of block IDs are all shuffle block IDs. With push based shuffle,
+   * the shuffle block ID could be either unmerged shuffle block IDs or merged shuffle chunk
+   * IDs. For a given stream of shuffle blocks to be fetched in one request, they would be either
+   * all unmerged shuffle blocks or all merged shuffle chunks.
+   * @param blockIds block ID array
+   * @return whether the array contains only shuffle block IDs
+   */
+  private boolean areShuffleBlocksOrChunks(String[] blockIds) {
     for (String blockId : blockIds) {
-      if (!blockId.startsWith("shuffle_")) {
+      if (!blockId.startsWith(SHUFFLE_BLOCK_PREFIX) &&
+          !blockId.startsWith(SHUFFLE_CHUNK_PREFIX)) {
         return false;
       }
     }
     return true;
   }
 
+  /** Creates either a {@link FetchShuffleBlocks} or {@link FetchShuffleBlockChunks} message. */
+  private AbstractFetchShuffleBlocks createFetchShuffleBlocksOrChunksMsg(
+      String appId,
+      String execId,
+      String[] blockIds) {
+    if (blockIds[0].startsWith(SHUFFLE_CHUNK_PREFIX)) {
+      return createFetchShuffleMsgAndBuildBlockIds(appId, execId, blockIds, true);
+    } else {
+      return createFetchShuffleMsgAndBuildBlockIds(appId, execId, blockIds, false);
+    }
+  }
+
   /**
-   * Create FetchShuffleBlocks message and rebuild internal blockIds by
+   * Create FetchShuffleBlocks/FetchShuffleBlockChunks message and rebuild internal blockIds by
    * analyzing the pass in blockIds.
    */
-  private FetchShuffleBlocks createFetchShuffleBlocksMsgAndBuildBlockIds(
-      String appId, String execId, String[] blockIds) {
+  private AbstractFetchShuffleBlocks createFetchShuffleMsgAndBuildBlockIds(
+      String appId,
+      String execId,
+      String[] blockIds,
+      boolean areMergedChunks) {

Review comment:
       Here we assume that either all the IDs are chunks or all are blocks.
   That is not what `areShuffleBlocksOrChunks` validates, though: a mix of both passes that check.
   
   Should we make `areShuffleBlocksOrChunks` stricter?
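One way to make the check stricter, sketched with hypothetical names, is to pick the expected prefix from the first ID and require every ID to share it. The prefix values `shuffle_` and `shuffleChunk_` are assumptions for this sketch. A `shuffleChunk_...` ID does not start with `shuffle_`, so the two prefixes cannot shadow each other. Whether a mixed array should fall back to `OpenBlocks` (by returning `false`, as here) or be rejected outright is a design choice.

```java
// Hypothetical stricter variant of areShuffleBlocksOrChunks; prefix values are assumptions.
final class ShuffleIdCheck {
  static final String SHUFFLE_BLOCK_PREFIX = "shuffle_";
  static final String SHUFFLE_CHUNK_PREFIX = "shuffleChunk_";

  // True only if every ID is an unmerged shuffle block, or every ID is a merged chunk.
  static boolean areAllShuffleBlocksOrAllChunks(String[] blockIds) {
    if (blockIds.length == 0) {
      return false; // the fetcher's constructor already rejects empty arrays
    }
    String prefix;
    if (blockIds[0].startsWith(SHUFFLE_BLOCK_PREFIX)) {
      prefix = SHUFFLE_BLOCK_PREFIX;
    } else if (blockIds[0].startsWith(SHUFFLE_CHUNK_PREFIX)) {
      prefix = SHUFFLE_CHUNK_PREFIX;
    } else {
      return false;
    }
    for (String blockId : blockIds) {
      if (!blockId.startsWith(prefix)) {
        return false;
      }
    }
    return true;
  }
}
```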

##########
File path: 
common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java
##########
@@ -413,6 +466,47 @@ public ManagedBuffer next() {
     }
   }
 
+  private class ShuffleChunkManagedBufferIterator implements Iterator<ManagedBuffer> {
+
+    private int reduceIdx = 0;
+    private int chunkIdx = 0;
+
+    private final String appId;
+    private final int shuffleId;
+    private final int[] reduceIds;
+    private final int[][] chunkIds;
+
+    ShuffleChunkManagedBufferIterator(FetchShuffleBlockChunks msg) {
+      appId = msg.appId;
+      shuffleId = msg.shuffleId;
+      reduceIds = msg.reduceIds;
+      chunkIds = msg.chunkIds;
+    }
+
+    @Override
+    public boolean hasNext() {
+      // reduceIds.length must equal to chunkIds.length, and the passed in FetchShuffleBlockChunks
+      // must have non-empty reduceIds and chunkIds, see the checking logic in
+      // OneForOneBlockFetcher.
+      assert(reduceIds.length != 0 && reduceIds.length == chunkIds.length);
+      return reduceIdx < reduceIds.length && chunkIdx < chunkIds[reduceIdx].length;
+    }
+
+    @Override
+    public ManagedBuffer next() {
+      ManagedBuffer block = mergeManager.getMergedBlockData(
+        appId, shuffleId, reduceIds[reduceIdx], chunkIds[reduceIdx][chunkIdx]);
+      if (chunkIdx < chunkIds[reduceIdx].length - 1) {
+        chunkIdx += 1;
+      } else {
+        chunkIdx = 0;
+        reduceIdx += 1;
+      }
+      metrics.blockTransferRateBytes.mark(block != null ? block.size() : 0);

Review comment:
       When would `block` be `null`?

##########
File path: 
common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
##########
@@ -88,82 +93,124 @@ public OneForOneBlockFetcher(
     if (blockIds.length == 0) {
       throw new IllegalArgumentException("Zero-sized blockIds array");
     }
-    if (!transportConf.useOldFetchProtocol() && isShuffleBlocks(blockIds)) {
+    if (!transportConf.useOldFetchProtocol() && areShuffleBlocksOrChunks(blockIds)) {
       this.blockIds = new String[blockIds.length];
-      this.message = createFetchShuffleBlocksMsgAndBuildBlockIds(appId, execId, blockIds);
+      this.message = createFetchShuffleBlocksOrChunksMsg(appId, execId, blockIds);
     } else {
       this.blockIds = blockIds;
       this.message = new OpenBlocks(appId, execId, blockIds);
     }
   }
 
-  private boolean isShuffleBlocks(String[] blockIds) {
+  /**
+   * Check if the array of block IDs are all shuffle block IDs. With push based shuffle,
+   * the shuffle block ID could be either unmerged shuffle block IDs or merged shuffle chunk
+   * IDs. For a given stream of shuffle blocks to be fetched in one request, they would be either
+   * all unmerged shuffle blocks or all merged shuffle chunks.
+   * @param blockIds block ID array
+   * @return whether the array contains only shuffle block IDs
+   */
+  private boolean areShuffleBlocksOrChunks(String[] blockIds) {
     for (String blockId : blockIds) {
-      if (!blockId.startsWith("shuffle_")) {
+      if (!blockId.startsWith(SHUFFLE_BLOCK_PREFIX) &&
+          !blockId.startsWith(SHUFFLE_CHUNK_PREFIX)) {
         return false;
       }
     }
     return true;
   }
 
+  /** Creates either a {@link FetchShuffleBlocks} or {@link FetchShuffleBlockChunks} message. */
+  private AbstractFetchShuffleBlocks createFetchShuffleBlocksOrChunksMsg(
+      String appId,
+      String execId,
+      String[] blockIds) {
+    if (blockIds[0].startsWith(SHUFFLE_CHUNK_PREFIX)) {
+      return createFetchShuffleMsgAndBuildBlockIds(appId, execId, blockIds, true);
+    } else {
+      return createFetchShuffleMsgAndBuildBlockIds(appId, execId, blockIds, false);
+    }
+  }
+
   /**
-   * Create FetchShuffleBlocks message and rebuild internal blockIds by
+   * Create FetchShuffleBlocks/FetchShuffleBlockChunks message and rebuild internal blockIds by
    * analyzing the pass in blockIds.
    */
-  private FetchShuffleBlocks createFetchShuffleBlocksMsgAndBuildBlockIds(
-      String appId, String execId, String[] blockIds) {
+  private AbstractFetchShuffleBlocks createFetchShuffleMsgAndBuildBlockIds(
+      String appId,
+      String execId,
+      String[] blockIds,
+      boolean areMergedChunks) {
     String[] firstBlock = splitBlockId(blockIds[0]);
     int shuffleId = Integer.parseInt(firstBlock[1]);
     boolean batchFetchEnabled = firstBlock.length == 5;
 
-    LinkedHashMap<Long, BlocksInfo> mapIdToBlocksInfo = new LinkedHashMap<>();
+    // In case of FetchShuffleBlocks, primaryId is mapId. For FetchShuffleBlockChunks, primaryId
+    // is reduceId.
+    LinkedHashMap<Number, BlocksInfo> primaryIdToBlocksInfo = new LinkedHashMap<>();
     for (String blockId : blockIds) {
       String[] blockIdParts = splitBlockId(blockId);
       if (Integer.parseInt(blockIdParts[1]) != shuffleId) {
         throw new IllegalArgumentException("Expected shuffleId=" + shuffleId +
           ", got:" + blockId);
       }
-      long mapId = Long.parseLong(blockIdParts[2]);
-      if (!mapIdToBlocksInfo.containsKey(mapId)) {
-        mapIdToBlocksInfo.put(mapId, new BlocksInfo());
+      Number primaryId;
+      if (!areMergedChunks) {
+        primaryId = Long.parseLong(blockIdParts[2]);
+      } else {
+        primaryId = Integer.parseInt(blockIdParts[2]);
+      }
+      if (!primaryIdToBlocksInfo.containsKey(primaryId)) {
+        primaryIdToBlocksInfo.put(primaryId, new BlocksInfo());
       }
-      BlocksInfo blocksInfoByMapId = mapIdToBlocksInfo.get(mapId);
-      blocksInfoByMapId.blockIds.add(blockId);
-      blocksInfoByMapId.reduceIds.add(Integer.parseInt(blockIdParts[3]));
+      BlocksInfo blocksInfoByPrimaryId = primaryIdToBlocksInfo.get(primaryId);
+      blocksInfoByPrimaryId.blockIds.add(blockId);
+      blocksInfoByPrimaryId.ids.add(Integer.parseInt(blockIdParts[3]));
       if (batchFetchEnabled) {
         // When we read continuous shuffle blocks in batch, we will reuse reduceIds in
         // FetchShuffleBlocks to store the start and end reduce id for range
         // [startReduceId, endReduceId).
         assert(blockIdParts.length == 5);
-        blocksInfoByMapId.reduceIds.add(Integer.parseInt(blockIdParts[4]));
+        blocksInfoByPrimaryId.ids.add(Integer.parseInt(blockIdParts[4]));
       }
     }
-    long[] mapIds = Longs.toArray(mapIdToBlocksInfo.keySet());
-    int[][] reduceIdArr = new int[mapIds.length][];
+    Set<Number> primaryIds = primaryIdToBlocksInfo.keySet();
+    // In case of FetchShuffleBlocks, secondaryIds are reduceIds. For FetchShuffleBlockChunks,
+    // secondaryIds are chunkIds.
+    int[][] secondaryIdsArray = new int[primaryIds.size()][];
     int blockIdIndex = 0;
-    for (int i = 0; i < mapIds.length; i++) {
-      BlocksInfo blocksInfoByMapId = mapIdToBlocksInfo.get(mapIds[i]);
-      reduceIdArr[i] = Ints.toArray(blocksInfoByMapId.reduceIds);
+    int secIndex = 0;
+    for (Number primaryId : primaryIds) {
+      BlocksInfo blocksInfoByPrimaryId = primaryIdToBlocksInfo.get(primaryId);
+      secondaryIdsArray[secIndex++] = Ints.toArray(blocksInfoByPrimaryId.ids);
 
-      // The `blockIds`'s order must be same with the read order specified in in FetchShuffleBlocks
-      // because the shuffle data's return order should match the `blockIds`'s order to ensure
-      // blockId and data match.
-      for (int j = 0; j < blocksInfoByMapId.blockIds.size(); j++) {
-        this.blockIds[blockIdIndex++] = blocksInfoByMapId.blockIds.get(j);
+      // The `blockIds`'s order must be same with the read order specified in FetchShuffleBlocks/
+      // FetchShuffleBlockChunks because the shuffle data's return order should match the
+      // `blockIds`'s order to ensure blockId and data match.
+      for (int j = 0; j < blocksInfoByPrimaryId.blockIds.size(); j++) {
+        this.blockIds[blockIdIndex++] = blocksInfoByPrimaryId.blockIds.get(j);
       }
     }
     assert(blockIdIndex == this.blockIds.length);
-
-    return new FetchShuffleBlocks(
-      appId, execId, shuffleId, mapIds, reduceIdArr, batchFetchEnabled);
+    if (!areMergedChunks) {
+      long[] mapIds = Longs.toArray(primaryIds);

Review comment:
      nit: `Longs.toArray` is a bit expensive, and the same goes for `Ints.toArray` below.
   If we can avoid them while keeping the code clean and concise, that would be preferable (a couple of other places in this PR also use these APIs).
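For illustration only, a minimal sketch of one way to skip the Guava conversions: unbox directly into the primitive arrays in a single loop. `PrimitiveArrays`, `toIntArray` and `toLongArray` are hypothetical helper names, not part of the PR. As far as I know, Guava's `Ints.toArray`/`Longs.toArray` first copy the collection into an intermediate `Object[]`, so the main saving is that extra allocation; whether it is measurable here would need profiling.

```java
import java.util.Collection;
import java.util.List;

/** Hypothetical helpers that unbox directly into primitive arrays (not part of the PR). */
final class PrimitiveArrays {
  private PrimitiveArrays() {}

  /** Copies a list of boxed ints into a new int[] in one pass. */
  static int[] toIntArray(List<Integer> values) {
    int[] result = new int[values.size()];
    int i = 0;
    for (int value : values) { // auto-unboxing; a null element would throw NullPointerException
      result[i++] = value;
    }
    return result;
  }

  /** Copies any collection of numbers (e.g. a map's key set) into a new long[] in one pass. */
  static long[] toLongArray(Collection<? extends Number> values) {
    long[] result = new long[values.size()];
    int i = 0;
    for (Number value : values) {
      result[i++] = value.longValue();
    }
    return result;
  }
}
```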

##########
File path: 
common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
##########
@@ -88,82 +93,124 @@ public OneForOneBlockFetcher(
     if (blockIds.length == 0) {
       throw new IllegalArgumentException("Zero-sized blockIds array");
     }
-    if (!transportConf.useOldFetchProtocol() && isShuffleBlocks(blockIds)) {
+    if (!transportConf.useOldFetchProtocol() && areShuffleBlocksOrChunks(blockIds)) {
       this.blockIds = new String[blockIds.length];
-      this.message = createFetchShuffleBlocksMsgAndBuildBlockIds(appId, execId, blockIds);
+      this.message = createFetchShuffleBlocksOrChunksMsg(appId, execId, blockIds);
     } else {
       this.blockIds = blockIds;
       this.message = new OpenBlocks(appId, execId, blockIds);
     }
   }
 
-  private boolean isShuffleBlocks(String[] blockIds) {
+  /**
+   * Check if the array of block IDs are all shuffle block IDs. With push based shuffle,
+   * the shuffle block ID could be either unmerged shuffle block IDs or merged shuffle chunk
+   * IDs. For a given stream of shuffle blocks to be fetched in one request, they would be either
+   * all unmerged shuffle blocks or all merged shuffle chunks.
+   * @param blockIds block ID array
+   * @return whether the array contains only shuffle block IDs
+   */
+  private boolean areShuffleBlocksOrChunks(String[] blockIds) {
     for (String blockId : blockIds) {
-      if (!blockId.startsWith("shuffle_")) {
+      if (!blockId.startsWith(SHUFFLE_BLOCK_PREFIX) &&
+          !blockId.startsWith(SHUFFLE_CHUNK_PREFIX)) {

Review comment:
      super nit: As coded, the `SHUFFLE_CHUNK_PREFIX` check here is redundant, though I am fine keeping it for clarity.
   Btw, note that we are skipping a '_' suffix check here.

##########
File path: 
common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
##########
@@ -246,6 +304,14 @@ private void failRemainingBlocks(String[] failedBlockIds, Throwable e) {
     }
   }
 
+  private void failSingleBlockChunk(String shuffleBlockChunkId, Throwable e) {
+    try {
+      listener.onBlockFetchFailure(shuffleBlockChunkId, e);
+    } catch (Exception e2) {
+      logger.error("Error from blockFetchFailure callback", e2);
+    }
+  }

Review comment:
      Can we have `failRemainingBlocks` delegate to `failSingleBlockChunk` now?
   ```
     private void failRemainingBlocks(String[] failedBlockIds, Throwable e) {
       Arrays.stream(failedBlockIds).forEach(blockId -> failSingleBlockChunk(blockId, e));
     }
   ```
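A self-contained sketch of the suggested delegation, assuming a simplified, hypothetical `FailureListener` interface in place of the fetcher's real listener, and recording callback errors in a list instead of logging them so the sketch has no logging dependency. It shows that one throwing callback does not stop the remaining blocks from being failed.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/** Hypothetical stand-in for the block fetch listener. */
interface FailureListener {
  void onBlockFetchFailure(String blockId, Throwable e);
}

/** Sketch of the delegation; method names mirror the review, the class itself is hypothetical. */
final class FailureNotifier {
  private final FailureListener listener;
  // Errors thrown by the callback; the PR logs these instead.
  final List<Exception> callbackErrors = new ArrayList<>();

  FailureNotifier(FailureListener listener) {
    this.listener = listener;
  }

  void failSingleBlockChunk(String blockId, Throwable e) {
    try {
      listener.onBlockFetchFailure(blockId, e);
    } catch (Exception e2) {
      // Swallowing here keeps one bad callback from aborting the loop below.
      callbackErrors.add(e2);
    }
  }

  void failRemainingBlocks(String[] failedBlockIds, Throwable e) {
    Arrays.stream(failedBlockIds).forEach(blockId -> failSingleBlockChunk(blockId, e));
  }
}
```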

##########
File path: 
common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
##########
@@ -88,82 +93,124 @@ public OneForOneBlockFetcher(
     if (blockIds.length == 0) {
       throw new IllegalArgumentException("Zero-sized blockIds array");
     }
-    if (!transportConf.useOldFetchProtocol() && isShuffleBlocks(blockIds)) {
+    if (!transportConf.useOldFetchProtocol() && areShuffleBlocksOrChunks(blockIds)) {
       this.blockIds = new String[blockIds.length];
-      this.message = createFetchShuffleBlocksMsgAndBuildBlockIds(appId, execId, blockIds);
+      this.message = createFetchShuffleBlocksOrChunksMsg(appId, execId, blockIds);
     } else {
       this.blockIds = blockIds;
       this.message = new OpenBlocks(appId, execId, blockIds);
     }
   }
 
-  private boolean isShuffleBlocks(String[] blockIds) {
+  /**
+   * Check if the array of block IDs are all shuffle block IDs. With push based shuffle,
+   * the shuffle block ID could be either unmerged shuffle block IDs or merged shuffle chunk
+   * IDs. For a given stream of shuffle blocks to be fetched in one request, they would be either
+   * all unmerged shuffle blocks or all merged shuffle chunks.
+   * @param blockIds block ID array
+   * @return whether the array contains only shuffle block IDs
+   */
+  private boolean areShuffleBlocksOrChunks(String[] blockIds) {
     for (String blockId : blockIds) {
-      if (!blockId.startsWith("shuffle_")) {
+      if (!blockId.startsWith(SHUFFLE_BLOCK_PREFIX) &&
+          !blockId.startsWith(SHUFFLE_CHUNK_PREFIX)) {
         return false;
       }
     }
     return true;

Review comment:
       ```suggestion
     return Arrays.stream(blockIds).allMatch(blockId ->
         blockId.startsWith(SHUFFLE_BLOCK_PREFIX) || blockId.startsWith(SHUFFLE_CHUNK_PREFIX));
   ```
   
   
   Review note: the IDs matched by `startsWith(SHUFFLE_BLOCK_PREFIX)` are a superset of those matched by `startsWith(SHUFFLE_CHUNK_PREFIX)`, though I am fine with keeping the two checks separate in the interest of clarity.
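A sketch of the same check written with `allMatch`, which returns `true` only when every ID has one of the two prefixes, matching the loop in the diff (a bare `anyMatch` on the negated condition would return the inverse). The prefix values are assumptions for illustration: with `"shuffle_"` as the block prefix the chunk check is not redundant, whereas a block prefix without the trailing underscore would subsume it.

```java
import java.util.Arrays;

/** Hypothetical sketch of the prefix check using allMatch (not the PR's code). */
final class ShuffleBlockIdCheck {
  private ShuffleBlockIdCheck() {}

  // Assumed prefix values, for illustration only.
  static final String SHUFFLE_BLOCK_PREFIX = "shuffle_";
  static final String SHUFFLE_CHUNK_PREFIX = "shuffleChunk_";

  /** True only if every ID starts with one of the two prefixes (vacuously true for an empty array). */
  static boolean areShuffleBlocksOrChunks(String[] blockIds) {
    return Arrays.stream(blockIds)
        .allMatch(id -> id.startsWith(SHUFFLE_BLOCK_PREFIX) || id.startsWith(SHUFFLE_CHUNK_PREFIX));
  }
}
```

The empty-array case does not matter for the fetcher itself, since its constructor already rejects a zero-sized `blockIds` array.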

##########
File path: 
common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
##########
@@ -88,82 +93,124 @@ public OneForOneBlockFetcher(
     if (blockIds.length == 0) {
       throw new IllegalArgumentException("Zero-sized blockIds array");
     }
-    if (!transportConf.useOldFetchProtocol() && isShuffleBlocks(blockIds)) {
+    if (!transportConf.useOldFetchProtocol() && areShuffleBlocksOrChunks(blockIds)) {
       this.blockIds = new String[blockIds.length];
-      this.message = createFetchShuffleBlocksMsgAndBuildBlockIds(appId, execId, blockIds);
+      this.message = createFetchShuffleBlocksOrChunksMsg(appId, execId, blockIds);
     } else {
       this.blockIds = blockIds;
       this.message = new OpenBlocks(appId, execId, blockIds);
     }
   }
 
-  private boolean isShuffleBlocks(String[] blockIds) {
+  /**
+   * Check if the array of block IDs are all shuffle block IDs. With push based shuffle,
+   * the shuffle block ID could be either unmerged shuffle block IDs or merged shuffle chunk
+   * IDs. For a given stream of shuffle blocks to be fetched in one request, they would be either
+   * all unmerged shuffle blocks or all merged shuffle chunks.
+   * @param blockIds block ID array
+   * @return whether the array contains only shuffle block IDs
+   */
+  private boolean areShuffleBlocksOrChunks(String[] blockIds) {
     for (String blockId : blockIds) {
-      if (!blockId.startsWith("shuffle_")) {
+      if (!blockId.startsWith(SHUFFLE_BLOCK_PREFIX) &&
+          !blockId.startsWith(SHUFFLE_CHUNK_PREFIX)) {
         return false;
       }
     }
     return true;
   }
 
+  /** Creates either a {@link FetchShuffleBlocks} or {@link FetchShuffleBlockChunks} message. */
+  private AbstractFetchShuffleBlocks createFetchShuffleBlocksOrChunksMsg(
+      String appId,
+      String execId,
+      String[] blockIds) {
+    if (blockIds[0].startsWith(SHUFFLE_CHUNK_PREFIX)) {
+      return createFetchShuffleMsgAndBuildBlockIds(appId, execId, blockIds, true);
+    } else {
+      return createFetchShuffleMsgAndBuildBlockIds(appId, execId, blockIds, false);
+    }
+  }
+
   /**
-   * Create FetchShuffleBlocks message and rebuild internal blockIds by
+   * Create FetchShuffleBlocks/FetchShuffleBlockChunks message and rebuild internal blockIds by
    * analyzing the pass in blockIds.
    */
-  private FetchShuffleBlocks createFetchShuffleBlocksMsgAndBuildBlockIds(
-      String appId, String execId, String[] blockIds) {
+  private AbstractFetchShuffleBlocks createFetchShuffleMsgAndBuildBlockIds(
+      String appId,
+      String execId,
+      String[] blockIds,
+      boolean areMergedChunks) {
     String[] firstBlock = splitBlockId(blockIds[0]);
     int shuffleId = Integer.parseInt(firstBlock[1]);
     boolean batchFetchEnabled = firstBlock.length == 5;
 
-    LinkedHashMap<Long, BlocksInfo> mapIdToBlocksInfo = new LinkedHashMap<>();
+    // In case of FetchShuffleBlocks, primaryId is mapId. For FetchShuffleBlockChunks, primaryId
+    // is reduceId.
+    LinkedHashMap<Number, BlocksInfo> primaryIdToBlocksInfo = new LinkedHashMap<>();
     for (String blockId : blockIds) {
       String[] blockIdParts = splitBlockId(blockId);
       if (Integer.parseInt(blockIdParts[1]) != shuffleId) {
         throw new IllegalArgumentException("Expected shuffleId=" + shuffleId +
           ", got:" + blockId);
       }
-      long mapId = Long.parseLong(blockIdParts[2]);
-      if (!mapIdToBlocksInfo.containsKey(mapId)) {
-        mapIdToBlocksInfo.put(mapId, new BlocksInfo());
+      Number primaryId;
+      if (!areMergedChunks) {
+        primaryId = Long.parseLong(blockIdParts[2]);
+      } else {
+        primaryId = Integer.parseInt(blockIdParts[2]);
+      }
+      if (!primaryIdToBlocksInfo.containsKey(primaryId)) {
+        primaryIdToBlocksInfo.put(primaryId, new BlocksInfo());
       }
-      BlocksInfo blocksInfoByMapId = mapIdToBlocksInfo.get(mapId);
-      blocksInfoByMapId.blockIds.add(blockId);
-      blocksInfoByMapId.reduceIds.add(Integer.parseInt(blockIdParts[3]));
+      BlocksInfo blocksInfoByPrimaryId = primaryIdToBlocksInfo.get(primaryId);
+      blocksInfoByPrimaryId.blockIds.add(blockId);
+      blocksInfoByPrimaryId.ids.add(Integer.parseInt(blockIdParts[3]));
       if (batchFetchEnabled) {
         // When we read continuous shuffle blocks in batch, we will reuse reduceIds in
         // FetchShuffleBlocks to store the start and end reduce id for range
         // [startReduceId, endReduceId).
         assert(blockIdParts.length == 5);
-        blocksInfoByMapId.reduceIds.add(Integer.parseInt(blockIdParts[4]));
+        blocksInfoByPrimaryId.ids.add(Integer.parseInt(blockIdParts[4]));
       }
     }
-    long[] mapIds = Longs.toArray(mapIdToBlocksInfo.keySet());
-    int[][] reduceIdArr = new int[mapIds.length][];
+    Set<Number> primaryIds = primaryIdToBlocksInfo.keySet();
+    // In case of FetchShuffleBlocks, secondaryIds are reduceIds. For FetchShuffleBlockChunks,
+    // secondaryIds are chunkIds.
+    int[][] secondaryIdsArray = new int[primaryIds.size()][];
     int blockIdIndex = 0;
-    for (int i = 0; i < mapIds.length; i++) {
-      BlocksInfo blocksInfoByMapId = mapIdToBlocksInfo.get(mapIds[i]);
-      reduceIdArr[i] = Ints.toArray(blocksInfoByMapId.reduceIds);
+    int secIndex = 0;
+    for (Number primaryId : primaryIds) {
+      BlocksInfo blocksInfoByPrimaryId = primaryIdToBlocksInfo.get(primaryId);

Review comment:
      Iterate over `primaryIdToBlocksInfo.entrySet()` instead?
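A minimal sketch of the `entrySet()` variant, assuming a simplified, hypothetical `BlocksInfo` with `blockIds` and `ids` lists and caller-supplied output arrays. Iterating entries yields both the primary ID and its `BlocksInfo` in insertion order, so the per-key `get` lookup goes away while the `LinkedHashMap` still preserves the order that keeps `blockIds` aligned with the returned data.

```java
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

/** Hypothetical sketch of building the request arrays via entrySet() (not the PR's code). */
final class EntrySetGrouping {
  private EntrySetGrouping() {}

  /** Simplified stand-in for the PR's BlocksInfo. */
  static final class BlocksInfo {
    final List<String> blockIds = new ArrayList<>();
    final List<Integer> ids = new ArrayList<>();
  }

  /**
   * Walks the map once, in insertion order, filling primaryIdsOut and orderedBlockIdsOut
   * and returning the secondary IDs grouped per primary ID.
   */
  static int[][] build(
      LinkedHashMap<Number, BlocksInfo> primaryIdToBlocksInfo,
      long[] primaryIdsOut,
      String[] orderedBlockIdsOut) {
    int[][] secondaryIdsArray = new int[primaryIdToBlocksInfo.size()][];
    int secIndex = 0;
    int blockIdIndex = 0;
    for (Map.Entry<Number, BlocksInfo> entry : primaryIdToBlocksInfo.entrySet()) {
      BlocksInfo info = entry.getValue(); // no second map lookup needed
      primaryIdsOut[secIndex] = entry.getKey().longValue();
      int[] ids = new int[info.ids.size()];
      for (int i = 0; i < ids.length; i++) {
        ids[i] = info.ids.get(i);
      }
      secondaryIdsArray[secIndex++] = ids;
      for (String blockId : info.blockIds) {
        orderedBlockIdsOut[blockIdIndex++] = blockId;
      }
    }
    return secondaryIdsArray;
  }
}
```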

##########
File path: 
common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
##########
@@ -276,9 +342,13 @@ public void onComplete(String streamId) throws IOException {
     @Override
     public void onFailure(String streamId, Throwable cause) throws IOException {
       channel.close();

Review comment:
      What is the expected behavior if closing the channel throws an exception? (The failure might itself be caused by `onData` throwing an exception, for example.)
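One possible policy, sketched under assumptions and not taken from the PR: always report the original failure, and attach any `close()` error to it as a suppressed exception. `StreamFailureHandling`, `FailureCallback` and `handleFailure` are hypothetical names, and `java.io.Closeable` stands in for the download channel.

```java
import java.io.Closeable;
import java.io.IOException;

/** Hypothetical sketch of one onFailure policy (not the PR's code). */
final class StreamFailureHandling {
  private StreamFailureHandling() {}

  /** Stand-in for whatever callback reports the failed block(s). */
  interface FailureCallback {
    void onFailure(Throwable cause);
  }

  /**
   * Closes the channel and always reports the original cause. If close() throws an
   * IOException, it is attached to the cause as a suppressed exception instead of replacing it.
   */
  static void handleFailure(Closeable channel, Throwable cause, FailureCallback callback) {
    try {
      channel.close();
    } catch (IOException closeError) {
      cause.addSuppressed(closeError);
    } finally {
      // Runs even if close() throws an unchecked exception, which then still propagates.
      callback.onFailure(cause);
    }
  }
}
```

The design choice here is that the root cause (for example, an exception from `onData`) stays the primary error the listener sees, while the close failure remains available for debugging.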

##########
File path: 
common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
##########
@@ -88,82 +93,124 @@ public OneForOneBlockFetcher(
     if (blockIds.length == 0) {
       throw new IllegalArgumentException("Zero-sized blockIds array");
     }
-    if (!transportConf.useOldFetchProtocol() && isShuffleBlocks(blockIds)) {
+    if (!transportConf.useOldFetchProtocol() && areShuffleBlocksOrChunks(blockIds)) {
       this.blockIds = new String[blockIds.length];
-      this.message = createFetchShuffleBlocksMsgAndBuildBlockIds(appId, execId, blockIds);
+      this.message = createFetchShuffleBlocksOrChunksMsg(appId, execId, blockIds);
     } else {
       this.blockIds = blockIds;
       this.message = new OpenBlocks(appId, execId, blockIds);
     }
   }
 
-  private boolean isShuffleBlocks(String[] blockIds) {
+  /**
+   * Check if the array of block IDs are all shuffle block IDs. With push based shuffle,
+   * the shuffle block ID could be either unmerged shuffle block IDs or merged shuffle chunk
+   * IDs. For a given stream of shuffle blocks to be fetched in one request, they would be either
+   * all unmerged shuffle blocks or all merged shuffle chunks.
+   * @param blockIds block ID array
+   * @return whether the array contains only shuffle block IDs
+   */
+  private boolean areShuffleBlocksOrChunks(String[] blockIds) {
     for (String blockId : blockIds) {
-      if (!blockId.startsWith("shuffle_")) {
+      if (!blockId.startsWith(SHUFFLE_BLOCK_PREFIX) &&
+          !blockId.startsWith(SHUFFLE_CHUNK_PREFIX)) {
         return false;
       }
     }
     return true;
   }
 
+  /** Creates either a {@link FetchShuffleBlocks} or {@link FetchShuffleBlockChunks} message. */
+  private AbstractFetchShuffleBlocks createFetchShuffleBlocksOrChunksMsg(
+      String appId,
+      String execId,
+      String[] blockIds) {
+    if (blockIds[0].startsWith(SHUFFLE_CHUNK_PREFIX)) {
+      return createFetchShuffleMsgAndBuildBlockIds(appId, execId, blockIds, true);
+    } else {
+      return createFetchShuffleMsgAndBuildBlockIds(appId, execId, blockIds, false);
+    }
+  }
+
   /**
-   * Create FetchShuffleBlocks message and rebuild internal blockIds by
+   * Create FetchShuffleBlocks/FetchShuffleBlockChunks message and rebuild internal blockIds by
    * analyzing the pass in blockIds.
    */
-  private FetchShuffleBlocks createFetchShuffleBlocksMsgAndBuildBlockIds(
-      String appId, String execId, String[] blockIds) {
+  private AbstractFetchShuffleBlocks createFetchShuffleMsgAndBuildBlockIds(
+      String appId,
+      String execId,
+      String[] blockIds,
+      boolean areMergedChunks) {
     String[] firstBlock = splitBlockId(blockIds[0]);
     int shuffleId = Integer.parseInt(firstBlock[1]);
     boolean batchFetchEnabled = firstBlock.length == 5;
 
-    LinkedHashMap<Long, BlocksInfo> mapIdToBlocksInfo = new LinkedHashMap<>();
+    // In case of FetchShuffleBlocks, primaryId is mapId. For FetchShuffleBlockChunks, primaryId
+    // is reduceId.
+    LinkedHashMap<Number, BlocksInfo> primaryIdToBlocksInfo = new LinkedHashMap<>();
     for (String blockId : blockIds) {
       String[] blockIdParts = splitBlockId(blockId);
       if (Integer.parseInt(blockIdParts[1]) != shuffleId) {
         throw new IllegalArgumentException("Expected shuffleId=" + shuffleId +
           ", got:" + blockId);
       }
-      long mapId = Long.parseLong(blockIdParts[2]);
-      if (!mapIdToBlocksInfo.containsKey(mapId)) {
-        mapIdToBlocksInfo.put(mapId, new BlocksInfo());
+      Number primaryId;
+      if (!areMergedChunks) {
+        primaryId = Long.parseLong(blockIdParts[2]);
+      } else {
+        primaryId = Integer.parseInt(blockIdParts[2]);
+      }
+      if (!primaryIdToBlocksInfo.containsKey(primaryId)) {
+        primaryIdToBlocksInfo.put(primaryId, new BlocksInfo());
       }
-      BlocksInfo blocksInfoByMapId = mapIdToBlocksInfo.get(mapId);
-      blocksInfoByMapId.blockIds.add(blockId);
-      blocksInfoByMapId.reduceIds.add(Integer.parseInt(blockIdParts[3]));
+      BlocksInfo blocksInfoByPrimaryId = primaryIdToBlocksInfo.get(primaryId);
+      blocksInfoByPrimaryId.blockIds.add(blockId);
+      blocksInfoByPrimaryId.ids.add(Integer.parseInt(blockIdParts[3]));
       if (batchFetchEnabled) {
         // When we read continuous shuffle blocks in batch, we will reuse reduceIds in
         // FetchShuffleBlocks to store the start and end reduce id for range
         // [startReduceId, endReduceId).
         assert(blockIdParts.length == 5);
-        blocksInfoByMapId.reduceIds.add(Integer.parseInt(blockIdParts[4]));
+        blocksInfoByPrimaryId.ids.add(Integer.parseInt(blockIdParts[4]));

Review comment:
      Update the comment above, or add a one-line note, describing what `blockIdParts[4]` can be.
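To make the note concrete, a sketch assuming the batch block ID layout `shuffle_<shuffleId>_<mapId>_<startReduceId>_<endReduceId>`, which matches the `[startReduceId, endReduceId)` comment in the diff. `String.split` stands in for the PR's `splitBlockId`, and `BlockIdParts.reduceRange` is a hypothetical helper. Whether merged chunk IDs can ever carry a fifth part is what the requested comment should state; the sketch does not assume either way.

```java
/** Hypothetical helper illustrating what parts[3] and parts[4] mean for a batch shuffle block ID. */
final class BlockIdParts {
  private BlockIdParts() {}

  /**
   * Assumed layout: shuffle_<shuffleId>_<mapId>_<startReduceId>_<endReduceId>.
   * Returns {startReduceId, endReduceId}, describing the half-open range [start, end).
   */
  static int[] reduceRange(String batchBlockId) {
    String[] parts = batchBlockId.split("_");
    if (parts.length != 5) {
      throw new IllegalArgumentException("Not a batch block ID: " + batchBlockId);
    }
    int startReduceId = Integer.parseInt(parts[3]); // inclusive
    int endReduceId = Integer.parseInt(parts[4]);   // exclusive; this is blockIdParts[4]
    return new int[] {startReduceId, endReduceId};
  }
}
```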

##########
File path: 
common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/FetchShuffleBlockChunks.java
##########
@@ -0,0 +1,126 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle.protocol;
+
+import java.util.Arrays;
+
+import io.netty.buffer.ByteBuf;
+
+import org.apache.spark.network.protocol.Encoders;
+
+// Needed by ScalaDoc. See SPARK-7726
+import static org.apache.spark.network.shuffle.protocol.BlockTransferMessage.Type;
+
+
+/**
+ * Request to read a set of block chunks. Returns {@link StreamHandle}.
+ *
+ * @since 3.2.0
+ */
+public class FetchShuffleBlockChunks extends AbstractFetchShuffleBlocks {
+  // The length of reduceIds must equal to chunkIds.size().

Review comment:
      How strong is this assumption? Do we foresee a future evolution where it could break, or is it tied to the protocol in nontrivial ways?
   For example, `encode` and `decode` do not currently rely on it (we could have avoided writing `chunkIdsLen` if they did).
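To illustrate the trade-off, a sketch of a hypothetical, simplified message (not the PR's `FetchShuffleBlockChunks`) that enforces `reduceIds.length == chunkIds.length` in its constructor and therefore omits the outer `chunkIds` count on the wire. `java.nio.ByteBuffer` stands in for Netty's `ByteBuf` to keep it self-contained.

```java
import java.nio.ByteBuffer;

/** Hypothetical message that relies on reduceIds.length == chunkIds.length (not the PR's class). */
final class ChunkRequestSketch {
  final int[] reduceIds;
  final int[][] chunkIds;

  ChunkRequestSketch(int[] reduceIds, int[][] chunkIds) {
    if (reduceIds.length != chunkIds.length) {
      throw new IllegalArgumentException("reduceIds and chunkIds must have the same length");
    }
    this.reduceIds = reduceIds;
    this.chunkIds = chunkIds;
  }

  int encodedLength() {
    int len = 4 + 4 * reduceIds.length; // reduceIds count + values
    for (int[] ids : chunkIds) {
      len += 4 + 4 * ids.length; // per-reduce chunk count + values; no outer chunkIds count
    }
    return len;
  }

  void encode(ByteBuffer buf) {
    buf.putInt(reduceIds.length);
    for (int id : reduceIds) {
      buf.putInt(id);
    }
    for (int[] ids : chunkIds) {
      buf.putInt(ids.length);
      for (int id : ids) {
        buf.putInt(id);
      }
    }
  }

  static ChunkRequestSketch decode(ByteBuffer buf) {
    int n = buf.getInt();
    int[] reduceIds = new int[n];
    for (int i = 0; i < n; i++) {
      reduceIds[i] = buf.getInt();
    }
    int[][] chunkIds = new int[n][]; // reuses n instead of reading a separate count
    for (int i = 0; i < n; i++) {
      int[] ids = new int[buf.getInt()];
      for (int j = 0; j < ids.length; j++) {
        ids[j] = buf.getInt();
      }
      chunkIds[i] = ids;
    }
    return new ChunkRequestSketch(reduceIds, chunkIds);
  }
}
```

In this sketch the saving is four bytes per message; the cost is that the invariant becomes part of the wire format, so relaxing it later would need a protocol change.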

##########
File path: 
common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
##########
@@ -88,82 +93,124 @@ public OneForOneBlockFetcher(
     if (blockIds.length == 0) {
       throw new IllegalArgumentException("Zero-sized blockIds array");
     }
-    if (!transportConf.useOldFetchProtocol() && isShuffleBlocks(blockIds)) {
+    if (!transportConf.useOldFetchProtocol() && areShuffleBlocksOrChunks(blockIds)) {
       this.blockIds = new String[blockIds.length];
-      this.message = createFetchShuffleBlocksMsgAndBuildBlockIds(appId, execId, blockIds);
+      this.message = createFetchShuffleBlocksOrChunksMsg(appId, execId, blockIds);
     } else {
       this.blockIds = blockIds;
       this.message = new OpenBlocks(appId, execId, blockIds);
     }
   }
 
-  private boolean isShuffleBlocks(String[] blockIds) {
+  /**
+   * Check if the array of block IDs are all shuffle block IDs. With push based shuffle,
+   * the shuffle block ID could be either unmerged shuffle block IDs or merged shuffle chunk
+   * IDs. For a given stream of shuffle blocks to be fetched in one request, they would be either
+   * all unmerged shuffle blocks or all merged shuffle chunks.
+   * @param blockIds block ID array
+   * @return whether the array contains only shuffle block IDs
+   */
+  private boolean areShuffleBlocksOrChunks(String[] blockIds) {
     for (String blockId : blockIds) {
-      if (!blockId.startsWith("shuffle_")) {
+      if (!blockId.startsWith(SHUFFLE_BLOCK_PREFIX) &&
+          !blockId.startsWith(SHUFFLE_CHUNK_PREFIX)) {
         return false;
       }
     }
     return true;
   }
 
+  /** Creates either a {@link FetchShuffleBlocks} or {@link FetchShuffleBlockChunks} message. */
+  private AbstractFetchShuffleBlocks createFetchShuffleBlocksOrChunksMsg(
+      String appId,
+      String execId,
+      String[] blockIds) {
+    if (blockIds[0].startsWith(SHUFFLE_CHUNK_PREFIX)) {
+      return createFetchShuffleMsgAndBuildBlockIds(appId, execId, blockIds, true);
+    } else {
+      return createFetchShuffleMsgAndBuildBlockIds(appId, execId, blockIds, false);
+    }
+  }
+
   /**
-   * Create FetchShuffleBlocks message and rebuild internal blockIds by
+   * Create FetchShuffleBlocks/FetchShuffleBlockChunks message and rebuild internal blockIds by
    * analyzing the pass in blockIds.
    */
-  private FetchShuffleBlocks createFetchShuffleBlocksMsgAndBuildBlockIds(
-      String appId, String execId, String[] blockIds) {
+  private AbstractFetchShuffleBlocks createFetchShuffleMsgAndBuildBlockIds(
+      String appId,
+      String execId,
+      String[] blockIds,
+      boolean areMergedChunks) {
     String[] firstBlock = splitBlockId(blockIds[0]);
     int shuffleId = Integer.parseInt(firstBlock[1]);
     boolean batchFetchEnabled = firstBlock.length == 5;
 
-    LinkedHashMap<Long, BlocksInfo> mapIdToBlocksInfo = new LinkedHashMap<>();
+    // In case of FetchShuffleBlocks, primaryId is mapId. For FetchShuffleBlockChunks, primaryId
+    // is reduceId.
+    LinkedHashMap<Number, BlocksInfo> primaryIdToBlocksInfo = new LinkedHashMap<>();
     for (String blockId : blockIds) {
       String[] blockIdParts = splitBlockId(blockId);
       if (Integer.parseInt(blockIdParts[1]) != shuffleId) {
         throw new IllegalArgumentException("Expected shuffleId=" + shuffleId +
           ", got:" + blockId);
       }
-      long mapId = Long.parseLong(blockIdParts[2]);
-      if (!mapIdToBlocksInfo.containsKey(mapId)) {
-        mapIdToBlocksInfo.put(mapId, new BlocksInfo());
+      Number primaryId;
+      if (!areMergedChunks) {
+        primaryId = Long.parseLong(blockIdParts[2]);
+      } else {
+        primaryId = Integer.parseInt(blockIdParts[2]);
+      }
+      if (!primaryIdToBlocksInfo.containsKey(primaryId)) {
+        primaryIdToBlocksInfo.put(primaryId, new BlocksInfo());
       }
-      BlocksInfo blocksInfoByMapId = mapIdToBlocksInfo.get(mapId);
-      blocksInfoByMapId.blockIds.add(blockId);
-      blocksInfoByMapId.reduceIds.add(Integer.parseInt(blockIdParts[3]));
+      BlocksInfo blocksInfoByPrimaryId = primaryIdToBlocksInfo.get(primaryId);
+      blocksInfoByPrimaryId.blockIds.add(blockId);
+      blocksInfoByPrimaryId.ids.add(Integer.parseInt(blockIdParts[3]));
       if (batchFetchEnabled) {
         // When we read continuous shuffle blocks in batch, we will reuse reduceIds in
         // FetchShuffleBlocks to store the start and end reduce id for range
         // [startReduceId, endReduceId).
         assert(blockIdParts.length == 5);
-        blocksInfoByMapId.reduceIds.add(Integer.parseInt(blockIdParts[4]));
+        blocksInfoByPrimaryId.ids.add(Integer.parseInt(blockIdParts[4]));
       }
     }
-    long[] mapIds = Longs.toArray(mapIdToBlocksInfo.keySet());
-    int[][] reduceIdArr = new int[mapIds.length][];
+    Set<Number> primaryIds = primaryIdToBlocksInfo.keySet();
+    // In case of FetchShuffleBlocks, secondaryIds are reduceIds. For FetchShuffleBlockChunks,
+    // secondaryIds are chunkIds.
+    int[][] secondaryIdsArray = new int[primaryIds.size()][];
     int blockIdIndex = 0;
-    for (int i = 0; i < mapIds.length; i++) {
-      BlocksInfo blocksInfoByMapId = mapIdToBlocksInfo.get(mapIds[i]);
-      reduceIdArr[i] = Ints.toArray(blocksInfoByMapId.reduceIds);
+    int secIndex = 0;
+    for (Number primaryId : primaryIds) {
+      BlocksInfo blocksInfoByPrimaryId = primaryIdToBlocksInfo.get(primaryId);
+      secondaryIdsArray[secIndex++] = Ints.toArray(blocksInfoByPrimaryId.ids);
 
-      // The `blockIds`'s order must be same with the read order specified in in FetchShuffleBlocks
-      // because the shuffle data's return order should match the `blockIds`'s order to ensure
-      // blockId and data match.
-      for (int j = 0; j < blocksInfoByMapId.blockIds.size(); j++) {
-        this.blockIds[blockIdIndex++] = blocksInfoByMapId.blockIds.get(j);
+      // The `blockIds`'s order must be same with the read order specified in 
FetchShuffleBlocks/
+      // FetchShuffleBlockChunks because the shuffle data's return order 
should match the
+      // `blockIds`'s order to ensure blockId and data match.
+      for (int j = 0; j < blocksInfoByPrimaryId.blockIds.size(); j++) {
+        this.blockIds[blockIdIndex++] = blocksInfoByPrimaryId.blockIds.get(j);

Review comment:
       ```suggestion
       for (String blockId : blocksInfoByPrimaryId.blockIds) {
           this.blockIds[blockIdIndex++] = blockId;
       }
   ```




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]



---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to