mridulm commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r645277961
##########
File path: common/network-common/src/main/java/org/apache/spark/network/protocol/MergedBlockMetaRequest.java
##########
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+import com.google.common.base.Objects;
+import io.netty.buffer.ByteBuf;
+
+/**
+ * Request to find the meta information for the specified merged block. The meta information
+ * contains the number of chunks in the merged blocks and the maps ids in each chunk.
+ *
+ * @since 3.2.0
+ */
+public class MergedBlockMetaRequest extends AbstractMessage implements RequestMessage {
+  public final long requestId;
+  public final String appId;
+  public final int shuffleId;
+  public final int reduceId;
+
+  public MergedBlockMetaRequest(long requestId, String appId, int shuffleId, int reduceId) {
+    super(null, false);
+    this.requestId = requestId;
+    this.appId = appId;
+    this.shuffleId = shuffleId;
+    this.reduceId = reduceId;
+  }
+
+  @Override
+  public Type type() {
+    return Type.MergedBlockMetaRequest;
+  }
+
+  @Override
+  public int encodedLength() {
+    return 8 + Encoders.Strings.encodedLength(appId) + 8;
+  }
+
+  @Override
+  public void encode(ByteBuf buf) {
+    buf.writeLong(requestId);
+    Encoders.Strings.encode(buf, appId);
+    buf.writeInt(shuffleId);
+    buf.writeInt(reduceId);
+  }
+
+  public static MergedBlockMetaRequest decode(ByteBuf buf) {
+    long requestId = buf.readLong();
+    String appId = Encoders.Strings.decode(buf);
+    int shuffleId = buf.readInt();
+    int reduceId = buf.readInt();
+    return new MergedBlockMetaRequest(requestId, appId, shuffleId, reduceId);
+  }
+
+  @Override
+  public int hashCode() {
+    return Objects.hashCode(requestId, appId, shuffleId, reduceId);
+  }
+
+  @Override
+  public boolean equals(Object other) {
+    if (other instanceof MergedBlockMetaRequest) {
+      MergedBlockMetaRequest o = (MergedBlockMetaRequest) other;
+      return requestId == o.requestId && Objects.equal(appId, o.appId)
+        && shuffleId == o.shuffleId && reduceId == o.reduceId;

Review comment:
nit: move the appId check to last.

##########
File path: common/network-common/src/main/java/org/apache/spark/network/client/BaseResponseCallback.java
##########
@@ -0,0 +1,31 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.client;
+
+/**
+ * A basic callback. This is extended by {@link RpcResponseCallback} and
+ * {@link MergedBlockMetaResponseCallback} so that both RpcRequests and MergedBlockMetaRequests
+ * can be handled in {@link TransportResponseHandler} a similar way.
+ *
+ * @since 3.2.0
+ */
+public interface BaseResponseCallback {

Review comment:
nit: I don't have a good suggestion myself, but any thoughts on a better name for this interface ? Thoughts @Ngone51, @attilapiros ?
##########
File path: common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java
##########
@@ -413,6 +466,47 @@ public ManagedBuffer next() {
     }
   }
 
+  private class ShuffleChunkManagedBufferIterator implements Iterator<ManagedBuffer> {
+
+    private int reduceIdx = 0;
+    private int chunkIdx = 0;
+
+    private final String appId;
+    private final int shuffleId;
+    private final int[] reduceIds;
+    private final int[][] chunkIds;
+
+    ShuffleChunkManagedBufferIterator(FetchShuffleBlockChunks msg) {
+      appId = msg.appId;
+      shuffleId = msg.shuffleId;
+      reduceIds = msg.reduceIds;
+      chunkIds = msg.chunkIds;
+    }
+
+    @Override
+    public boolean hasNext() {
+      // reduceIds.length must equal to chunkIds.length, and the passed in FetchShuffleBlockChunks
+      // must have non-empty reduceIds and chunkIds, see the checking logic in
+      // OneForOneBlockFetcher.
+      assert(reduceIds.length != 0 && reduceIds.length == chunkIds.length);
+      return reduceIdx < reduceIds.length && chunkIdx < chunkIds[reduceIdx].length;
+    }
+
+    @Override
+    public ManagedBuffer next() {

Review comment:
Reviewer note: the `Iterator` contract requires that `next` check whether `hasNext` is true, and throw `NoSuchElementException` otherwise. Unfortunately, the other iterators in `ExternalBlockHandler` are not doing it either ...
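For reference, a minimal self-contained sketch of the contract described above; the class and element type here are illustrative, not taken from the PR:

```java
import java.util.Iterator;
import java.util.NoSuchElementException;

// Illustrative iterator honoring the java.util.Iterator contract:
// next() fails fast with NoSuchElementException once hasNext() turns false.
class RangeIterator implements Iterator<Integer> {
  private final int end;
  private int current;

  RangeIterator(int start, int end) {
    this.current = start;
    this.end = end;
  }

  @Override
  public boolean hasNext() {
    return current < end;
  }

  @Override
  public Integer next() {
    if (!hasNext()) {
      throw new NoSuchElementException("no element at index " + current);
    }
    return current++;
  }
}
```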
##########
File path: common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
##########
@@ -88,82 +93,124 @@ public OneForOneBlockFetcher(
     if (blockIds.length == 0) {
       throw new IllegalArgumentException("Zero-sized blockIds array");
     }
-    if (!transportConf.useOldFetchProtocol() && isShuffleBlocks(blockIds)) {
+    if (!transportConf.useOldFetchProtocol() && areShuffleBlocksOrChunks(blockIds)) {
       this.blockIds = new String[blockIds.length];
-      this.message = createFetchShuffleBlocksMsgAndBuildBlockIds(appId, execId, blockIds);
+      this.message = createFetchShuffleBlocksOrChunksMsg(appId, execId, blockIds);
     } else {
       this.blockIds = blockIds;
       this.message = new OpenBlocks(appId, execId, blockIds);
     }
   }
 
-  private boolean isShuffleBlocks(String[] blockIds) {
+  /**
+   * Check if the array of block IDs are all shuffle block IDs. With push based shuffle,
+   * the shuffle block ID could be either unmerged shuffle block IDs or merged shuffle chunk
+   * IDs. For a given stream of shuffle blocks to be fetched in one request, they would be either
+   * all unmerged shuffle blocks or all merged shuffle chunks.
+   * @param blockIds block ID array
+   * @return whether the array contains only shuffle block IDs
+   */
+  private boolean areShuffleBlocksOrChunks(String[] blockIds) {
     for (String blockId : blockIds) {
-      if (!blockId.startsWith("shuffle_")) {
+      if (!blockId.startsWith(SHUFFLE_BLOCK_PREFIX) &&
+        !blockId.startsWith(SHUFFLE_CHUNK_PREFIX)) {
         return false;
       }
     }
     return true;
   }
 
+  /** Creates either a {@link FetchShuffleBlocks} or {@link FetchShuffleBlockChunks} message. */
+  private AbstractFetchShuffleBlocks createFetchShuffleBlocksOrChunksMsg(
+      String appId,
+      String execId,
+      String[] blockIds) {
+    if (blockIds[0].startsWith(SHUFFLE_CHUNK_PREFIX)) {
+      return createFetchShuffleMsgAndBuildBlockIds(appId, execId, blockIds, true);
+    } else {
+      return createFetchShuffleMsgAndBuildBlockIds(appId, execId, blockIds, false);
+    }
+  }
+
   /**
-   * Create FetchShuffleBlocks message and rebuild internal blockIds by
+   * Create FetchShuffleBlocks/FetchShuffleBlockChunks message and rebuild internal blockIds by
    * analyzing the pass in blockIds.
    */
-  private FetchShuffleBlocks createFetchShuffleBlocksMsgAndBuildBlockIds(
-      String appId, String execId, String[] blockIds) {
+  private AbstractFetchShuffleBlocks createFetchShuffleMsgAndBuildBlockIds(
+      String appId,
+      String execId,
+      String[] blockIds,
+      boolean areMergedChunks) {
     String[] firstBlock = splitBlockId(blockIds[0]);
     int shuffleId = Integer.parseInt(firstBlock[1]);
     boolean batchFetchEnabled = firstBlock.length == 5;
 
-    LinkedHashMap<Long, BlocksInfo> mapIdToBlocksInfo = new LinkedHashMap<>();
+    // In case of FetchShuffleBlocks, primaryId is mapId. For FetchShuffleBlockChunks, primaryId
+    // is reduceId.
+    LinkedHashMap<Number, BlocksInfo> primaryIdToBlocksInfo = new LinkedHashMap<>();
     for (String blockId : blockIds) {
       String[] blockIdParts = splitBlockId(blockId);
       if (Integer.parseInt(blockIdParts[1]) != shuffleId) {
         throw new IllegalArgumentException("Expected shuffleId=" + shuffleId +
           ", got:" + blockId);
       }
-      long mapId = Long.parseLong(blockIdParts[2]);
-      if (!mapIdToBlocksInfo.containsKey(mapId)) {
-        mapIdToBlocksInfo.put(mapId, new BlocksInfo());
+      Number primaryId;
+      if (!areMergedChunks) {
+        primaryId = Long.parseLong(blockIdParts[2]);
+      } else {
+        primaryId = Integer.parseInt(blockIdParts[2]);
+      }
+      if (!primaryIdToBlocksInfo.containsKey(primaryId)) {
+        primaryIdToBlocksInfo.put(primaryId, new BlocksInfo());
      }
-      BlocksInfo blocksInfoByMapId = mapIdToBlocksInfo.get(mapId);
-      blocksInfoByMapId.blockIds.add(blockId);
-      blocksInfoByMapId.reduceIds.add(Integer.parseInt(blockIdParts[3]));
+      BlocksInfo blocksInfoByPrimaryId = primaryIdToBlocksInfo.get(primaryId);
+      blocksInfoByPrimaryId.blockIds.add(blockId);
+      blocksInfoByPrimaryId.ids.add(Integer.parseInt(blockIdParts[3]));

Review comment:
Add a one-line note on what `blockIdParts[3]` can be.

##########
File path: core/src/main/scala/org/apache/spark/storage/BlockId.scala
##########
@@ -124,11 +134,12 @@ class UnrecognizedBlockId(name: String)
 @DeveloperApi
 object BlockId {
   val RDD = "rdd_([0-9]+)_([0-9]+)".r
-  val SHUFFLE = "shuffle_([0-9]+)_([0-9]+)_([0-9]+)".r
+  val SHUFFLE = "shuffle_([0-9]+)_(-?[0-9]+)_([0-9]+)".r

Review comment:
nit: `\\d+` instead ?
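As a side note on the two spellings, a small self-contained check (the sample block id below is made up): in Java/Scala default regex mode `\d` and `[0-9]` match the same ASCII digits, and the added `-?` is what admits a negative map id.

```java
import java.util.regex.Pattern;

public class ShuffleIdPatternCheck {
  public static void main(String[] args) {
    // Current spelling from the diff vs. the terser \\d+ spelling.
    Pattern verbose = Pattern.compile("shuffle_([0-9]+)_(-?[0-9]+)_([0-9]+)");
    Pattern terse = Pattern.compile("shuffle_(\\d+)_(-?\\d+)_(\\d+)");

    String blockId = "shuffle_3_-2_7"; // hypothetical id with a negative map id
    System.out.println(verbose.matcher(blockId).matches()); // true
    System.out.println(terse.matcher(blockId).matches());   // true
  }
}
```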
##########
File path: common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java
##########
@@ -128,24 +134,23 @@ protected void handleMessage(
       BlockTransferMessage msgObj,
       TransportClient client,
       RpcResponseCallback callback) {
-    if (msgObj instanceof FetchShuffleBlocks || msgObj instanceof OpenBlocks) {
+    if (msgObj instanceof AbstractFetchShuffleBlocks || msgObj instanceof OpenBlocks) {
       final Timer.Context responseDelayContext = metrics.openBlockRequestLatencyMillis.time();
       try {
         int numBlockIds;
         long streamId;
-        if (msgObj instanceof FetchShuffleBlocks) {
-          FetchShuffleBlocks msg = (FetchShuffleBlocks) msgObj;
+        if (msgObj instanceof AbstractFetchShuffleBlocks) {
+          AbstractFetchShuffleBlocks msg = (AbstractFetchShuffleBlocks) msgObj;
           checkAuth(client, msg.appId);
-          numBlockIds = 0;
-          if (msg.batchFetchEnabled) {
-            numBlockIds = msg.mapIds.length;
+          numBlockIds = ((AbstractFetchShuffleBlocks) msgObj).getNumBlocks();

Review comment:
`getNumBlocks` makes this code cleaner.

##########
File path: common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java
##########
@@ -199,14 +200,31 @@ public void handle(ResponseMessage message) throws Exception {
       }
     } else if (message instanceof RpcFailure) {
       RpcFailure resp = (RpcFailure) message;
-      RpcResponseCallback listener = outstandingRpcs.get(resp.requestId);
+      BaseResponseCallback listener = outstandingRpcs.get(resp.requestId);
       if (listener == null) {
         logger.warn("Ignoring response for RPC {} from {} ({}) since it is not outstanding",
           resp.requestId, getRemoteAddress(channel), resp.errorString);
       } else {
         outstandingRpcs.remove(resp.requestId);
         listener.onFailure(new RuntimeException(resp.errorString));
       }
+    } else if (message instanceof MergedBlockMetaSuccess) {
+      MergedBlockMetaSuccess resp = (MergedBlockMetaSuccess) message;
+      MergedBlockMetaResponseCallback listener =
+        (MergedBlockMetaResponseCallback) outstandingRpcs.get(resp.requestId);
+      if (listener == null) {
+        logger.warn(
+          "Ignoring response for MergedBlockMetaRequest {} from {} ({} bytes) since it is not" +
+            " outstanding", resp.requestId, getRemoteAddress(channel), resp.body().size());
+        resp.body().release();
+      } else {
+        outstandingRpcs.remove(resp.requestId);
+        try {
+          listener.onSuccess(resp.getNumChunks(), resp.body());
+        } finally {
+          resp.body().release();

Review comment:
nit: move `resp.body().release()` into a try/finally that covers this entire else block.
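A minimal self-contained analog of the suggested restructuring (the names here are illustrative, not from the PR): releasing in a `finally` that covers the whole branch guarantees exactly one release on the listener path, the no-listener path, and the exception path.

```java
import java.util.HashMap;
import java.util.Map;

class ReleaseOnAllPaths {
  interface RefCounted { void release(); }
  interface Callback { void onSuccess(RefCounted body); }

  private final Map<Long, Callback> outstanding = new HashMap<>();

  void handleSuccess(long requestId, RefCounted body) {
    try {
      Callback listener = outstanding.remove(requestId);
      if (listener != null) {
        listener.onSuccess(body);
      }
      // else: the response is not outstanding; the finally below still releases it.
    } finally {
      body.release();
    }
  }
}
```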
##########
File path: common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java
##########
@@ -189,9 +194,14 @@ protected void handleMessage(
     } else if (msgObj instanceof GetLocalDirsForExecutors) {
       GetLocalDirsForExecutors msg = (GetLocalDirsForExecutors) msgObj;
       checkAuth(client, msg.appId);
-      Map<String, String[]> localDirs = blockManager.getLocalDirs(msg.appId, msg.execIds);
+      String[] execIdsForBlockResolver = Arrays.stream(msg.execIds)
+        .filter(execId -> !SHUFFLE_MERGER_IDENTIFIER.equals(execId)).toArray(String[]::new);
+      Map<String, String[]> localDirs = blockManager.getLocalDirs(msg.appId,
+        execIdsForBlockResolver);
+      if (Arrays.asList(msg.execIds).contains(SHUFFLE_MERGER_IDENTIFIER)) {
+        localDirs.put(SHUFFLE_MERGER_IDENTIFIER, mergeManager.getMergedBlockDirs(msg.appId));
+      }

Review comment:
```suggestion
      Set<String> execIdsForBlockResolver = Sets.newHashSet(msg.execIds);
      boolean fetchMergedBlockDirs = execIdsForBlockResolver.remove(SHUFFLE_MERGER_IDENTIFIER);
      Map<String, String[]> localDirs = blockManager.getLocalDirs(msg.appId,
        execIdsForBlockResolver);
      if (fetchMergedBlockDirs) {
        localDirs.put(SHUFFLE_MERGER_IDENTIFIER, mergeManager.getMergedBlockDirs(msg.appId));
      }
```
With a corresponding change in `blockManager.getLocalDirs` to take a set of executor ids instead of an array.

##########
File path: common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
##########
@@ -88,82 +93,124 @@ public OneForOneBlockFetcher(
     if (blockIds.length == 0) {
       throw new IllegalArgumentException("Zero-sized blockIds array");
     }
-    if (!transportConf.useOldFetchProtocol() && isShuffleBlocks(blockIds)) {
+    if (!transportConf.useOldFetchProtocol() && areShuffleBlocksOrChunks(blockIds)) {
       this.blockIds = new String[blockIds.length];
-      this.message = createFetchShuffleBlocksMsgAndBuildBlockIds(appId, execId, blockIds);
+      this.message = createFetchShuffleBlocksOrChunksMsg(appId, execId, blockIds);

Review comment:
Just to clarify: we are not modifying the old fetch protocol at all.

##########
File path: common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
##########
@@ -88,82 +93,124 @@ public OneForOneBlockFetcher(
     if (blockIds.length == 0) {
       throw new IllegalArgumentException("Zero-sized blockIds array");
     }
-    if (!transportConf.useOldFetchProtocol() && isShuffleBlocks(blockIds)) {
+    if (!transportConf.useOldFetchProtocol() && areShuffleBlocksOrChunks(blockIds)) {
       this.blockIds = new String[blockIds.length];
-      this.message = createFetchShuffleBlocksMsgAndBuildBlockIds(appId, execId, blockIds);
+      this.message = createFetchShuffleBlocksOrChunksMsg(appId, execId, blockIds);
     } else {
       this.blockIds = blockIds;
       this.message = new OpenBlocks(appId, execId, blockIds);
     }
   }
 
-  private boolean isShuffleBlocks(String[] blockIds) {
+  /**
+   * Check if the array of block IDs are all shuffle block IDs. With push based shuffle,
+   * the shuffle block ID could be either unmerged shuffle block IDs or merged shuffle chunk
+   * IDs. For a given stream of shuffle blocks to be fetched in one request, they would be either
+   * all unmerged shuffle blocks or all merged shuffle chunks.
+   * @param blockIds block ID array
+   * @return whether the array contains only shuffle block IDs
+   */
+  private boolean areShuffleBlocksOrChunks(String[] blockIds) {
     for (String blockId : blockIds) {
-      if (!blockId.startsWith("shuffle_")) {
+      if (!blockId.startsWith(SHUFFLE_BLOCK_PREFIX) &&
+        !blockId.startsWith(SHUFFLE_CHUNK_PREFIX)) {
        return false;
      }
    }
    return true;
  }
 
+  /** Creates either a {@link FetchShuffleBlocks} or {@link FetchShuffleBlockChunks} message. */
+  private AbstractFetchShuffleBlocks createFetchShuffleBlocksOrChunksMsg(
+      String appId,
+      String execId,
+      String[] blockIds) {
+    if (blockIds[0].startsWith(SHUFFLE_CHUNK_PREFIX)) {
+      return createFetchShuffleMsgAndBuildBlockIds(appId, execId, blockIds, true);
+    } else {
+      return createFetchShuffleMsgAndBuildBlockIds(appId, execId, blockIds, false);
+    }
+  }
+
   /**
-   * Create FetchShuffleBlocks message and rebuild internal blockIds by
+   * Create FetchShuffleBlocks/FetchShuffleBlockChunks message and rebuild internal blockIds by
    * analyzing the pass in blockIds.
    */
-  private FetchShuffleBlocks createFetchShuffleBlocksMsgAndBuildBlockIds(
-      String appId, String execId, String[] blockIds) {
+  private AbstractFetchShuffleBlocks createFetchShuffleMsgAndBuildBlockIds(
+      String appId,
+      String execId,
+      String[] blockIds,
+      boolean areMergedChunks) {
     String[] firstBlock = splitBlockId(blockIds[0]);
     int shuffleId = Integer.parseInt(firstBlock[1]);
     boolean batchFetchEnabled = firstBlock.length == 5;
 
-    LinkedHashMap<Long, BlocksInfo> mapIdToBlocksInfo = new LinkedHashMap<>();
+    // In case of FetchShuffleBlocks, primaryId is mapId. For FetchShuffleBlockChunks, primaryId
+    // is reduceId.
+    LinkedHashMap<Number, BlocksInfo> primaryIdToBlocksInfo = new LinkedHashMap<>();
     for (String blockId : blockIds) {
       String[] blockIdParts = splitBlockId(blockId);
       if (Integer.parseInt(blockIdParts[1]) != shuffleId) {
         throw new IllegalArgumentException("Expected shuffleId=" + shuffleId +
           ", got:" + blockId);
       }
-      long mapId = Long.parseLong(blockIdParts[2]);
-      if (!mapIdToBlocksInfo.containsKey(mapId)) {
-        mapIdToBlocksInfo.put(mapId, new BlocksInfo());
+      Number primaryId;
+      if (!areMergedChunks) {
+        primaryId = Long.parseLong(blockIdParts[2]);
+      } else {
+        primaryId = Integer.parseInt(blockIdParts[2]);
+      }
+      if (!primaryIdToBlocksInfo.containsKey(primaryId)) {
+        primaryIdToBlocksInfo.put(primaryId, new BlocksInfo());
      }

Review comment:
nit:
```suggestion
      BlocksInfo blocksInfoByPrimaryId = primaryIdToBlocksInfo.computeIfAbsent(primaryId,
        id -> new BlocksInfo());
```
and remove the `get` below.

##########
File path: common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java
##########
@@ -333,14 +382,18 @@ public ShuffleMetrics() {
     final int[] mapIdAndReduceIds = new int[2 * blockIds.length];
     for (int i = 0; i < blockIds.length; i++) {
       String[] blockIdParts = blockIds[i].split("_");
-      if (blockIdParts.length != 4 || !blockIdParts[0].equals("shuffle")) {
+      if (blockIdParts.length != 4
+          || (!requestForMergedBlockChunks && !blockIdParts[0].equals(SHUFFLE_BLOCK_PREFIX))
+          || (requestForMergedBlockChunks && !blockIdParts[0].equals(SHUFFLE_CHUNK_PREFIX))) {
         throw new IllegalArgumentException("Unexpected shuffle block id format: " + blockIds[i]);
       }
       if (Integer.parseInt(blockIdParts[1]) != shuffleId) {
         throw new IllegalArgumentException("Expected shuffleId=" + shuffleId +
           ", got:" + blockIds[i]);
       }
+      // For regular blocks this is mapId. For chunks this is reduceId.
      mapIdAndReduceIds[2 * i] = Integer.parseInt(blockIdParts[2]);
+      // For regular blocks this is reduceId. For chunks this is chunkId.
       mapIdAndReduceIds[2 * i + 1] = Integer.parseInt(blockIdParts[3]);

Review comment:
Do we want to rename this variable (here and in the constructor) and this method, given the overloading of map/reduce vs reduce/chunk now ?

##########
File path: common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java
##########
@@ -413,6 +466,47 @@ public ManagedBuffer next() {
     }
   }
 
+  private class ShuffleChunkManagedBufferIterator implements Iterator<ManagedBuffer> {
+
+    private int reduceIdx = 0;
+    private int chunkIdx = 0;
+
+    private final String appId;
+    private final int shuffleId;
+    private final int[] reduceIds;
+    private final int[][] chunkIds;
+
+    ShuffleChunkManagedBufferIterator(FetchShuffleBlockChunks msg) {
+      appId = msg.appId;
+      shuffleId = msg.shuffleId;
+      reduceIds = msg.reduceIds;
+      chunkIds = msg.chunkIds;
+    }
+
+    @Override
+    public boolean hasNext() {
+      // reduceIds.length must equal to chunkIds.length, and the passed in FetchShuffleBlockChunks
+      // must have non-empty reduceIds and chunkIds, see the checking logic in
+      // OneForOneBlockFetcher.
+      assert(reduceIds.length != 0 && reduceIds.length == chunkIds.length);

Review comment:
Move this assertion into the constructor itself.
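A sketch of the suggested move, reusing the constructor from the diff (not the final implementation); validating once at construction fails fast instead of re-checking on every `hasNext()` call:

```java
// Sketch only: constructor from the diff with the assertion hoisted into it.
ShuffleChunkManagedBufferIterator(FetchShuffleBlockChunks msg) {
  appId = msg.appId;
  shuffleId = msg.shuffleId;
  reduceIds = msg.reduceIds;
  chunkIds = msg.chunkIds;
  // Malformed messages are rejected up front rather than on each hasNext().
  assert reduceIds.length != 0 && reduceIds.length == chunkIds.length;
}
```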
##########
File path: common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
##########
@@ -88,82 +93,124 @@ public OneForOneBlockFetcher(
     if (blockIds.length == 0) {
       throw new IllegalArgumentException("Zero-sized blockIds array");
     }
-    if (!transportConf.useOldFetchProtocol() && isShuffleBlocks(blockIds)) {
+    if (!transportConf.useOldFetchProtocol() && areShuffleBlocksOrChunks(blockIds)) {
       this.blockIds = new String[blockIds.length];
-      this.message = createFetchShuffleBlocksMsgAndBuildBlockIds(appId, execId, blockIds);
+      this.message = createFetchShuffleBlocksOrChunksMsg(appId, execId, blockIds);
     } else {
       this.blockIds = blockIds;
       this.message = new OpenBlocks(appId, execId, blockIds);
     }
   }
 
-  private boolean isShuffleBlocks(String[] blockIds) {
+  /**
+   * Check if the array of block IDs are all shuffle block IDs. With push based shuffle,
+   * the shuffle block ID could be either unmerged shuffle block IDs or merged shuffle chunk
+   * IDs. For a given stream of shuffle blocks to be fetched in one request, they would be either
+   * all unmerged shuffle blocks or all merged shuffle chunks.
+   * @param blockIds block ID array
+   * @return whether the array contains only shuffle block IDs
+   */
+  private boolean areShuffleBlocksOrChunks(String[] blockIds) {
     for (String blockId : blockIds) {
-      if (!blockId.startsWith("shuffle_")) {
+      if (!blockId.startsWith(SHUFFLE_BLOCK_PREFIX) &&
+        !blockId.startsWith(SHUFFLE_CHUNK_PREFIX)) {
        return false;
      }
    }
    return true;
  }
 
+  /** Creates either a {@link FetchShuffleBlocks} or {@link FetchShuffleBlockChunks} message. */
+  private AbstractFetchShuffleBlocks createFetchShuffleBlocksOrChunksMsg(
+      String appId,
+      String execId,
+      String[] blockIds) {
+    if (blockIds[0].startsWith(SHUFFLE_CHUNK_PREFIX)) {
+      return createFetchShuffleMsgAndBuildBlockIds(appId, execId, blockIds, true);
+    } else {
+      return createFetchShuffleMsgAndBuildBlockIds(appId, execId, blockIds, false);
+    }
+  }
+
   /**
-   * Create FetchShuffleBlocks message and rebuild internal blockIds by
+   * Create FetchShuffleBlocks/FetchShuffleBlockChunks message and rebuild internal blockIds by
    * analyzing the pass in blockIds.
    */
-  private FetchShuffleBlocks createFetchShuffleBlocksMsgAndBuildBlockIds(
-      String appId, String execId, String[] blockIds) {
+  private AbstractFetchShuffleBlocks createFetchShuffleMsgAndBuildBlockIds(
+      String appId,
+      String execId,
+      String[] blockIds,
+      boolean areMergedChunks) {

Review comment:
Here we are assuming that the blocks are either all merged chunks or all regular shuffle blocks. That is not the validation we perform in `areShuffleBlocksOrChunks`, where a mix of both can pass. Do we want to make `areShuffleBlocksOrChunks` stricter ?
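One possible stricter variant, as a sketch under the assumption that `SHUFFLE_CHUNK_PREFIX` is the more specific prefix and `SHUFFLE_BLOCK_PREFIX` the more general one (as the "superset" review note later in this thread suggests): pick the expected kind from the first id and require every id to match it, so a mix of block ids and chunk ids is rejected up front.

```java
// Sketch only: stricter than the check in the diff; rejects mixed arrays.
private boolean areShuffleBlocksOrChunks(String[] blockIds) {
  boolean expectChunks = blockIds[0].startsWith(SHUFFLE_CHUNK_PREFIX);
  for (String blockId : blockIds) {
    boolean isChunk = blockId.startsWith(SHUFFLE_CHUNK_PREFIX);
    boolean isBlock = !isChunk && blockId.startsWith(SHUFFLE_BLOCK_PREFIX);
    if (expectChunks ? !isChunk : !isBlock) {
      return false;
    }
  }
  return true;
}
```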
##########
File path: common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java
##########
@@ -413,6 +466,47 @@ public ManagedBuffer next() {
     }
   }
 
+  private class ShuffleChunkManagedBufferIterator implements Iterator<ManagedBuffer> {
+
+    private int reduceIdx = 0;
+    private int chunkIdx = 0;
+
+    private final String appId;
+    private final int shuffleId;
+    private final int[] reduceIds;
+    private final int[][] chunkIds;
+
+    ShuffleChunkManagedBufferIterator(FetchShuffleBlockChunks msg) {
+      appId = msg.appId;
+      shuffleId = msg.shuffleId;
+      reduceIds = msg.reduceIds;
+      chunkIds = msg.chunkIds;
+    }
+
+    @Override
+    public boolean hasNext() {
+      // reduceIds.length must equal to chunkIds.length, and the passed in FetchShuffleBlockChunks
+      // must have non-empty reduceIds and chunkIds, see the checking logic in
+      // OneForOneBlockFetcher.
+      assert(reduceIds.length != 0 && reduceIds.length == chunkIds.length);
+      return reduceIdx < reduceIds.length && chunkIdx < chunkIds[reduceIdx].length;
+    }
+
+    @Override
+    public ManagedBuffer next() {
+      ManagedBuffer block = mergeManager.getMergedBlockData(
+        appId, shuffleId, reduceIds[reduceIdx], chunkIds[reduceIdx][chunkIdx]);
+      if (chunkIdx < chunkIds[reduceIdx].length - 1) {
+        chunkIdx += 1;
+      } else {
+        chunkIdx = 0;
+        reduceIdx += 1;
+      }
+      metrics.blockTransferRateBytes.mark(block != null ? block.size() : 0);

Review comment:
When would `block` be `null` ?
##########
File path: common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
##########
@@ -88,82 +93,124 @@ public OneForOneBlockFetcher(
     String[] firstBlock = splitBlockId(blockIds[0]);
     int shuffleId = Integer.parseInt(firstBlock[1]);
     boolean batchFetchEnabled = firstBlock.length == 5;
 
-    LinkedHashMap<Long, BlocksInfo> mapIdToBlocksInfo = new LinkedHashMap<>();
+    // In case of FetchShuffleBlocks, primaryId is mapId. For FetchShuffleBlockChunks, primaryId
+    // is reduceId.
+    LinkedHashMap<Number, BlocksInfo> primaryIdToBlocksInfo = new LinkedHashMap<>();
     for (String blockId : blockIds) {
       String[] blockIdParts = splitBlockId(blockId);
       if (Integer.parseInt(blockIdParts[1]) != shuffleId) {
         throw new IllegalArgumentException("Expected shuffleId=" + shuffleId +
           ", got:" + blockId);
       }
-      long mapId = Long.parseLong(blockIdParts[2]);
-      if (!mapIdToBlocksInfo.containsKey(mapId)) {
-        mapIdToBlocksInfo.put(mapId, new BlocksInfo());
+      Number primaryId;
+      if (!areMergedChunks) {
+        primaryId = Long.parseLong(blockIdParts[2]);
+      } else {
+        primaryId = Integer.parseInt(blockIdParts[2]);
+      }
+      if (!primaryIdToBlocksInfo.containsKey(primaryId)) {
+        primaryIdToBlocksInfo.put(primaryId, new BlocksInfo());
      }
-      BlocksInfo blocksInfoByMapId = mapIdToBlocksInfo.get(mapId);
-      blocksInfoByMapId.blockIds.add(blockId);
-      blocksInfoByMapId.reduceIds.add(Integer.parseInt(blockIdParts[3]));
+      BlocksInfo blocksInfoByPrimaryId = primaryIdToBlocksInfo.get(primaryId);
+      blocksInfoByPrimaryId.blockIds.add(blockId);
+      blocksInfoByPrimaryId.ids.add(Integer.parseInt(blockIdParts[3]));
       if (batchFetchEnabled) {
         // When we read continuous shuffle blocks in batch, we will reuse reduceIds in
         // FetchShuffleBlocks to store the start and end reduce id for range
         // [startReduceId, endReduceId).
         assert(blockIdParts.length == 5);
-        blocksInfoByMapId.reduceIds.add(Integer.parseInt(blockIdParts[4]));
+        blocksInfoByPrimaryId.ids.add(Integer.parseInt(blockIdParts[4]));
       }
     }
 
-    long[] mapIds = Longs.toArray(mapIdToBlocksInfo.keySet());
-    int[][] reduceIdArr = new int[mapIds.length][];
+    Set<Number> primaryIds = primaryIdToBlocksInfo.keySet();
+    // In case of FetchShuffleBlocks, secondaryIds are reduceIds. For FetchShuffleBlockChunks,
+    // secondaryIds are chunkIds.
+    int[][] secondaryIdsArray = new int[primaryIds.size()][];
     int blockIdIndex = 0;
-    for (int i = 0; i < mapIds.length; i++) {
-      BlocksInfo blocksInfoByMapId = mapIdToBlocksInfo.get(mapIds[i]);
-      reduceIdArr[i] = Ints.toArray(blocksInfoByMapId.reduceIds);
+    int secIndex = 0;
+    for (Number primaryId : primaryIds) {
+      BlocksInfo blocksInfoByPrimaryId = primaryIdToBlocksInfo.get(primaryId);
+      secondaryIdsArray[secIndex++] = Ints.toArray(blocksInfoByPrimaryId.ids);
 
-      // The `blockIds`'s order must be same with the read order specified in in FetchShuffleBlocks
-      // because the shuffle data's return order should match the `blockIds`'s order to ensure
-      // blockId and data match.
-      for (int j = 0; j < blocksInfoByMapId.blockIds.size(); j++) {
-        this.blockIds[blockIdIndex++] = blocksInfoByMapId.blockIds.get(j);
+      // The `blockIds`'s order must be same with the read order specified in FetchShuffleBlocks/
+      // FetchShuffleBlockChunks because the shuffle data's return order should match the
+      // `blockIds`'s order to ensure blockId and data match.
+      for (int j = 0; j < blocksInfoByPrimaryId.blockIds.size(); j++) {
+        this.blockIds[blockIdIndex++] = blocksInfoByPrimaryId.blockIds.get(j);
       }
     }
     assert(blockIdIndex == this.blockIds.length);
-
-    return new FetchShuffleBlocks(
-      appId, execId, shuffleId, mapIds, reduceIdArr, batchFetchEnabled);
+    if (!areMergedChunks) {
+      long[] mapIds = Longs.toArray(primaryIds);

Review comment:
nit: `Longs.toArray` is a bit expensive - same for `Ints.toArray` below. If we can avoid it while keeping the code clean/concise, that would be preferable (there are a couple of other locations in this PR which use these APIs).
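A sketch of how the boxing helpers could be avoided while keeping the loop readable; types follow the diff, and `BlocksInfo.ids` is assumed to be a `List<Integer>` (consistent with the `Ints.toArray(blocksInfoByPrimaryId.ids)` call above):

```java
// Sketch only: one pass over the map, filling the primitive arrays directly
// instead of going through Longs.toArray/Ints.toArray.
long[] mapIds = new long[primaryIdToBlocksInfo.size()];
int[][] secondaryIdsArray = new int[mapIds.length][];
int i = 0;
for (Map.Entry<Number, BlocksInfo> entry : primaryIdToBlocksInfo.entrySet()) {
  mapIds[i] = entry.getKey().longValue();
  List<Integer> ids = entry.getValue().ids;
  int[] secondaryIds = new int[ids.size()];
  for (int j = 0; j < secondaryIds.length; j++) {
    secondaryIds[j] = ids.get(j);
  }
  secondaryIdsArray[i++] = secondaryIds;
}
```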
##########
File path: common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
##########
@@ -88,82 +93,124 @@ public OneForOneBlockFetcher(
-  private boolean isShuffleBlocks(String[] blockIds) {
+  /**
+   * Check if the array of block IDs are all shuffle block IDs. With push based shuffle,
+   * the shuffle block ID could be either unmerged shuffle block IDs or merged shuffle chunk
+   * IDs. For a given stream of shuffle blocks to be fetched in one request, they would be either
+   * all unmerged shuffle blocks or all merged shuffle chunks.
+   * @param blockIds block ID array
+   * @return whether the array contains only shuffle block IDs
+   */
+  private boolean areShuffleBlocksOrChunks(String[] blockIds) {
     for (String blockId : blockIds) {
-      if (!blockId.startsWith("shuffle_")) {
+      if (!blockId.startsWith(SHUFFLE_BLOCK_PREFIX) &&
+        !blockId.startsWith(SHUFFLE_CHUNK_PREFIX)) {

Review comment:
super nit: As coded, checking for `SHUFFLE_CHUNK_PREFIX` here is redundant - though I am fine with it for clarity. Btw, we are avoiding a '_' suffix check here.

##########
File path: common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
##########
@@ -246,6 +304,14 @@ private void failRemainingBlocks(String[] failedBlockIds, Throwable e) {
     }
   }
 
+  private void failSingleBlockChunk(String shuffleBlockChunkId, Throwable e) {
+    try {
+      listener.onBlockFetchFailure(shuffleBlockChunkId, e);
+    } catch (Exception e2) {
+      logger.error("Error from blockFetchFailure callback", e2);
+    }
+  }

Review comment:
We can have `failRemainingBlocks` delegate to `failSingleBlockChunk` now ?
```
private void failRemainingBlocks(String[] failedBlockIds, Throwable e) {
  Arrays.stream(failedBlockIds).forEach(blockId -> failSingleBlockChunk(blockId, e));
}
```

##########
File path: common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
##########
@@ -88,82 +93,124 @@ public OneForOneBlockFetcher(
-  private boolean isShuffleBlocks(String[] blockIds) {
+  /**
+   * Check if the array of block IDs are all shuffle block IDs. With push based shuffle,
+   * the shuffle block ID could be either unmerged shuffle block IDs or merged shuffle chunk
+   * IDs. For a given stream of shuffle blocks to be fetched in one request, they would be either
+   * all unmerged shuffle blocks or all merged shuffle chunks.
+   * @param blockIds block ID array
+   * @return whether the array contains only shuffle block IDs
+   */
+  private boolean areShuffleBlocksOrChunks(String[] blockIds) {
     for (String blockId : blockIds) {
-      if (!blockId.startsWith("shuffle_")) {
+      if (!blockId.startsWith(SHUFFLE_BLOCK_PREFIX) &&
+        !blockId.startsWith(SHUFFLE_CHUNK_PREFIX)) {
        return false;
      }
    }
    return true;

Review comment:
```suggestion
    return Arrays.stream(blockIds).noneMatch(blockId ->
      !blockId.startsWith(SHUFFLE_BLOCK_PREFIX) && !blockId.startsWith(SHUFFLE_CHUNK_PREFIX));
```
Review note: startsWith `SHUFFLE_BLOCK_PREFIX` is a superset of startsWith `SHUFFLE_CHUNK_PREFIX` - though I am fine with keeping them separate in the interest of clarity.
##########
File path: common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
##########
@@ -88,82 +93,124 @@ public OneForOneBlockFetcher(
-  private boolean isShuffleBlocks(String[] blockIds) {
+  /**
+   * Check if the array of block IDs are all shuffle block IDs. With push based shuffle,
+   * the shuffle block ID could be either unmerged shuffle block IDs or merged shuffle chunk
+   * IDs. For a given stream of shuffle blocks to be fetched in one request, they would be either
+   * all unmerged shuffle blocks or all merged shuffle chunks.
+   * @param blockIds block ID array
+   * @return whether the array contains only shuffle block IDs
+   */
+  private boolean areShuffleBlocksOrChunks(String[] blockIds) {
     for (String blockId : blockIds) {
-      if (!blockId.startsWith("shuffle_")) {
+      if (!blockId.startsWith(SHUFFLE_BLOCK_PREFIX) &&
+        !blockId.startsWith(SHUFFLE_CHUNK_PREFIX)) {
        return false;
      }
    }
    return true;
  }
 
+  /** Creates either a {@link FetchShuffleBlocks} or {@link FetchShuffleBlockChunks} message. */
+  private AbstractFetchShuffleBlocks createFetchShuffleBlocksOrChunksMsg(
+      String appId,
+      String execId,
+      String[] blockIds) {
+    if (blockIds[0].startsWith(SHUFFLE_CHUNK_PREFIX)) {
+      return createFetchShuffleMsgAndBuildBlockIds(appId, execId, blockIds, true);
+    } else {
+      return createFetchShuffleMsgAndBuildBlockIds(appId, execId, blockIds, false);
+    }
+  }
+
   /**
-   * Create FetchShuffleBlocks message and rebuild internal blockIds by
+   * Create FetchShuffleBlocks/FetchShuffleBlockChunks message and rebuild internal blockIds by
    * analyzing the pass in blockIds.
    */
-  private FetchShuffleBlocks createFetchShuffleBlocksMsgAndBuildBlockIds(
-      String appId, String execId, String[] blockIds) {
+  private AbstractFetchShuffleBlocks createFetchShuffleMsgAndBuildBlockIds(
+      String appId,
+      String execId,
+      String[] blockIds,
+      boolean areMergedChunks) {
     String[] firstBlock = splitBlockId(blockIds[0]);
     int shuffleId = Integer.parseInt(firstBlock[1]);
     boolean batchFetchEnabled = firstBlock.length == 5;
 
-    LinkedHashMap<Long, BlocksInfo> mapIdToBlocksInfo = new LinkedHashMap<>();
+    // In case of FetchShuffleBlocks, primaryId is mapId. For FetchShuffleBlockChunks, primaryId
+    // is reduceId.
+    LinkedHashMap<Number, BlocksInfo> primaryIdToBlocksInfo = new LinkedHashMap<>();
     for (String blockId : blockIds) {
       String[] blockIdParts = splitBlockId(blockId);
       if (Integer.parseInt(blockIdParts[1]) != shuffleId) {
         throw new IllegalArgumentException("Expected shuffleId=" + shuffleId +
           ", got:" + blockId);
       }
-      long mapId = Long.parseLong(blockIdParts[2]);
-      if (!mapIdToBlocksInfo.containsKey(mapId)) {
-        mapIdToBlocksInfo.put(mapId, new BlocksInfo());
+      Number primaryId;
+      if (!areMergedChunks) {
+        primaryId = Long.parseLong(blockIdParts[2]);
+      } else {
+        primaryId = Integer.parseInt(blockIdParts[2]);
+      }
+      if (!primaryIdToBlocksInfo.containsKey(primaryId)) {
+        primaryIdToBlocksInfo.put(primaryId, new BlocksInfo());
      }
-      BlocksInfo blocksInfoByMapId = mapIdToBlocksInfo.get(mapId);
-      blocksInfoByMapId.blockIds.add(blockId);
-      blocksInfoByMapId.reduceIds.add(Integer.parseInt(blockIdParts[3]));
+      BlocksInfo blocksInfoByPrimaryId = primaryIdToBlocksInfo.get(primaryId);
+      blocksInfoByPrimaryId.blockIds.add(blockId);
+      blocksInfoByPrimaryId.ids.add(Integer.parseInt(blockIdParts[3]));
       if (batchFetchEnabled) {
         // When we read continuous shuffle blocks in batch, we will reuse reduceIds in
         // FetchShuffleBlocks to store the start and end reduce id for range
         // [startReduceId, endReduceId).
        assert(blockIdParts.length == 5);
-        blocksInfoByMapId.reduceIds.add(Integer.parseInt(blockIdParts[4]));
+        blocksInfoByPrimaryId.ids.add(Integer.parseInt(blockIdParts[4]));
       }
     }
 
-    long[] mapIds = Longs.toArray(mapIdToBlocksInfo.keySet());
-    int[][] reduceIdArr = new int[mapIds.length][];
+    Set<Number> primaryIds = primaryIdToBlocksInfo.keySet();
+    // In case of FetchShuffleBlocks, secondaryIds are reduceIds. For FetchShuffleBlockChunks,
+    // secondaryIds are chunkIds.
+    int[][] secondaryIdsArray = new int[primaryIds.size()][];
     int blockIdIndex = 0;
-    for (int i = 0; i < mapIds.length; i++) {
-      BlocksInfo blocksInfoByMapId = mapIdToBlocksInfo.get(mapIds[i]);
-      reduceIdArr[i] = Ints.toArray(blocksInfoByMapId.reduceIds);
+    int secIndex = 0;
+    for (Number primaryId : primaryIds) {
+      BlocksInfo blocksInfoByPrimaryId = primaryIdToBlocksInfo.get(primaryId);

Review comment:
Iterate over `primaryIdToBlocksInfo.entrySet()` instead ?

##########
File path: common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
##########
@@ -276,9 +342,13 @@ public void onComplete(String streamId) throws IOException {
 
     @Override
     public void onFailure(String streamId, Throwable cause) throws IOException {
       channel.close();

Review comment:
What is the expected behavior if there are exceptions closing the channel ? (the failure perhaps being due to `onData` throwing an exception, for example)

##########
File path: common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
##########
@@ -88,82 +93,124 @@ public OneForOneBlockFetcher(
     if (blockIds.length == 0) {
       throw new IllegalArgumentException("Zero-sized blockIds array");
     }
-    if (!transportConf.useOldFetchProtocol() && isShuffleBlocks(blockIds)) {
+    if (!transportConf.useOldFetchProtocol() && areShuffleBlocksOrChunks(blockIds)) {
       this.blockIds = new String[blockIds.length];
-      this.message = createFetchShuffleBlocksMsgAndBuildBlockIds(appId, execId, blockIds);
+      this.message = createFetchShuffleBlocksOrChunksMsg(appId, execId, blockIds);
     } else {
       this.blockIds = blockIds;
       this.message = new OpenBlocks(appId, execId, blockIds);
     }
   }
 
-  private boolean isShuffleBlocks(String[] blockIds) {
+  /**
+   * Check if the array of block IDs are all shuffle block IDs. With push based shuffle,
+   * the shuffle block ID could be either unmerged shuffle block IDs or merged shuffle chunk
+   * IDs. For a given stream of shuffle blocks to be fetched in one request, they would be either
+   * all unmerged shuffle blocks or all merged shuffle chunks.
+   * @param blockIds block ID array
+   * @return whether the array contains only shuffle block IDs
+   */
+  private boolean areShuffleBlocksOrChunks(String[] blockIds) {
     for (String blockId : blockIds) {
-      if (!blockId.startsWith("shuffle_")) {
+      if (!blockId.startsWith(SHUFFLE_BLOCK_PREFIX) &&
+        !blockId.startsWith(SHUFFLE_CHUNK_PREFIX)) {
        return false;
      }
    }
    return true;
  }
 
+  /** Creates either a {@link FetchShuffleBlocks} or {@link FetchShuffleBlockChunks} message. */
+  private AbstractFetchShuffleBlocks createFetchShuffleBlocksOrChunksMsg(
+      String appId,
+      String execId,
+      String[] blockIds) {
+    if (blockIds[0].startsWith(SHUFFLE_CHUNK_PREFIX)) {
+      return createFetchShuffleMsgAndBuildBlockIds(appId, execId, blockIds, true);
+    } else {
+      return createFetchShuffleMsgAndBuildBlockIds(appId, execId, blockIds, false);
+    }
+  }
+
   /**
-   * Create FetchShuffleBlocks message and rebuild internal blockIds by
+   * Create FetchShuffleBlocks/FetchShuffleBlockChunks message and rebuild internal blockIds by
    * analyzing the pass in blockIds.
    */
-  private FetchShuffleBlocks createFetchShuffleBlocksMsgAndBuildBlockIds(
-      String appId, String execId, String[] blockIds) {
+  private AbstractFetchShuffleBlocks createFetchShuffleMsgAndBuildBlockIds(
+      String appId,
+      String execId,
+      String[] blockIds,
+      boolean areMergedChunks) {
     String[] firstBlock = splitBlockId(blockIds[0]);
     int shuffleId = Integer.parseInt(firstBlock[1]);
     boolean batchFetchEnabled = firstBlock.length == 5;
 
-    LinkedHashMap<Long, BlocksInfo> mapIdToBlocksInfo = new LinkedHashMap<>();
+    // In case of FetchShuffleBlocks, primaryId is mapId. For FetchShuffleBlockChunks, primaryId
+    // is reduceId.
+    LinkedHashMap<Number, BlocksInfo> primaryIdToBlocksInfo = new LinkedHashMap<>();
     for (String blockId : blockIds) {
       String[] blockIdParts = splitBlockId(blockId);
       if (Integer.parseInt(blockIdParts[1]) != shuffleId) {
         throw new IllegalArgumentException("Expected shuffleId=" + shuffleId +
           ", got:" + blockId);
       }
-      long mapId = Long.parseLong(blockIdParts[2]);
-      if (!mapIdToBlocksInfo.containsKey(mapId)) {
-        mapIdToBlocksInfo.put(mapId, new BlocksInfo());
+      Number primaryId;
+      if (!areMergedChunks) {
+        primaryId = Long.parseLong(blockIdParts[2]);
+      } else {
+        primaryId = Integer.parseInt(blockIdParts[2]);
+      }
+      if (!primaryIdToBlocksInfo.containsKey(primaryId)) {
+        primaryIdToBlocksInfo.put(primaryId, new BlocksInfo());
      }
-      BlocksInfo blocksInfoByMapId = mapIdToBlocksInfo.get(mapId);
-      blocksInfoByMapId.blockIds.add(blockId);
-      blocksInfoByMapId.reduceIds.add(Integer.parseInt(blockIdParts[3]));
+      BlocksInfo blocksInfoByPrimaryId = primaryIdToBlocksInfo.get(primaryId);
+      blocksInfoByPrimaryId.blockIds.add(blockId);
+      blocksInfoByPrimaryId.ids.add(Integer.parseInt(blockIdParts[3]));
       if (batchFetchEnabled) {
         // When we read continuous shuffle blocks in batch, we will reuse reduceIds in
         // FetchShuffleBlocks to store the start and end reduce id for range
         // [startReduceId, endReduceId).
         assert(blockIdParts.length == 5);
-        blocksInfoByMapId.reduceIds.add(Integer.parseInt(blockIdParts[4]));
+        blocksInfoByPrimaryId.ids.add(Integer.parseInt(blockIdParts[4]));

Review comment:
Update the comment above and add a one-line note on what `blockIdParts[4]` can be.

##########
File path: common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/FetchShuffleBlockChunks.java
##########
@@ -0,0 +1,126 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle.protocol;
+
+import java.util.Arrays;
+
+import io.netty.buffer.ByteBuf;
+
+import org.apache.spark.network.protocol.Encoders;
+
+// Needed by ScalaDoc. See SPARK-7726
+import static org.apache.spark.network.shuffle.protocol.BlockTransferMessage.Type;
+
+
+/**
+ * Request to read a set of block chunks. Returns {@link StreamHandle}.
+ *
+ * @since 3.2.0
+ */
+public class FetchShuffleBlockChunks extends AbstractFetchShuffleBlocks {
+  // The length of reduceIds must equal to chunkIds.size().

Review comment:
How strong is this assumption ? Do we see a future evolution where this can break ? Or is it tied to the protocol in nontrivial ways ?
As an example - `encode` and `decode` do not assume this currently (we could have avoided writing `chunkIdsLen` if they did).

##########
File path: common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
##########
@@ -88,82 +93,124 @@ public OneForOneBlockFetcher(
     if (blockIds.length == 0) {
       throw new IllegalArgumentException("Zero-sized blockIds array");
     }
-    if (!transportConf.useOldFetchProtocol() && isShuffleBlocks(blockIds)) {
+    if (!transportConf.useOldFetchProtocol() && areShuffleBlocksOrChunks(blockIds)) {
       this.blockIds = new String[blockIds.length];
-      this.message = createFetchShuffleBlocksMsgAndBuildBlockIds(appId, execId, blockIds);
+      this.message = createFetchShuffleBlocksOrChunksMsg(appId, execId, blockIds);
     } else {
       this.blockIds = blockIds;
       this.message = new OpenBlocks(appId, execId, blockIds);
     }
   }
 
-  private boolean isShuffleBlocks(String[] blockIds) {
+  /**
+   * Check if the array of block IDs are all shuffle block IDs. With push based shuffle,
+   * the shuffle block ID could be either unmerged shuffle block IDs or merged shuffle chunk
+   * IDs. For a given stream of shuffle blocks to be fetched in one request, they would be either
+   * all unmerged shuffle blocks or all merged shuffle chunks.
+   * @param blockIds block ID array
+   * @return whether the array contains only shuffle block IDs
+   */
+  private boolean areShuffleBlocksOrChunks(String[] blockIds) {
     for (String blockId : blockIds) {
-      if (!blockId.startsWith("shuffle_")) {
+      if (!blockId.startsWith(SHUFFLE_BLOCK_PREFIX) &&
+        !blockId.startsWith(SHUFFLE_CHUNK_PREFIX)) {
        return false;
      }
    }
    return true;
  }
 
+  /** Creates either a {@link FetchShuffleBlocks} or {@link FetchShuffleBlockChunks} message. */
+  private AbstractFetchShuffleBlocks createFetchShuffleBlocksOrChunksMsg(
+      String appId,
+      String execId,
+      String[] blockIds) {
+    if (blockIds[0].startsWith(SHUFFLE_CHUNK_PREFIX)) {
+      return createFetchShuffleMsgAndBuildBlockIds(appId, execId, blockIds, true);
+    } else {
+      return createFetchShuffleMsgAndBuildBlockIds(appId, execId, blockIds, false);
+    }
+  }
+
   /**
-   * Create FetchShuffleBlocks message and rebuild internal blockIds by
+   * Create FetchShuffleBlocks/FetchShuffleBlockChunks message and rebuild internal blockIds by
    * analyzing the pass in blockIds.
   */
-  private FetchShuffleBlocks createFetchShuffleBlocksMsgAndBuildBlockIds(
-      String appId, String execId, String[] blockIds) {
+  private AbstractFetchShuffleBlocks createFetchShuffleMsgAndBuildBlockIds(
+      String appId,
+      String execId,
+      String[] blockIds,
+      boolean areMergedChunks) {
     String[] firstBlock = splitBlockId(blockIds[0]);
     int shuffleId = Integer.parseInt(firstBlock[1]);
     boolean batchFetchEnabled = firstBlock.length == 5;
 
-    LinkedHashMap<Long, BlocksInfo> mapIdToBlocksInfo = new LinkedHashMap<>();
+    // In case of FetchShuffleBlocks, primaryId is mapId. For FetchShuffleBlockChunks, primaryId
+    // is reduceId.
+    LinkedHashMap<Number, BlocksInfo> primaryIdToBlocksInfo = new LinkedHashMap<>();
     for (String blockId : blockIds) {
       String[] blockIdParts = splitBlockId(blockId);
       if (Integer.parseInt(blockIdParts[1]) != shuffleId) {
         throw new IllegalArgumentException("Expected shuffleId=" + shuffleId +
           ", got:" + blockId);
       }
-      long mapId = Long.parseLong(blockIdParts[2]);
-      if (!mapIdToBlocksInfo.containsKey(mapId)) {
-        mapIdToBlocksInfo.put(mapId, new BlocksInfo());
+      Number primaryId;
+      if (!areMergedChunks) {
+        primaryId = Long.parseLong(blockIdParts[2]);
+      } else {
+        primaryId = Integer.parseInt(blockIdParts[2]);
+      }
+      if (!primaryIdToBlocksInfo.containsKey(primaryId)) {
+        primaryIdToBlocksInfo.put(primaryId, new BlocksInfo());
      }
-      BlocksInfo blocksInfoByMapId = mapIdToBlocksInfo.get(mapId);
-      blocksInfoByMapId.blockIds.add(blockId);
-      blocksInfoByMapId.reduceIds.add(Integer.parseInt(blockIdParts[3]));
+      BlocksInfo blocksInfoByPrimaryId = primaryIdToBlocksInfo.get(primaryId);
+      blocksInfoByPrimaryId.blockIds.add(blockId);
+      blocksInfoByPrimaryId.ids.add(Integer.parseInt(blockIdParts[3]));
       if (batchFetchEnabled) {
         // When we read continuous shuffle blocks in batch, we will reuse reduceIds in
         // FetchShuffleBlocks to store the start and end reduce id for range
         // [startReduceId, endReduceId).
         assert(blockIdParts.length == 5);
-        blocksInfoByMapId.reduceIds.add(Integer.parseInt(blockIdParts[4]));
+        blocksInfoByPrimaryId.ids.add(Integer.parseInt(blockIdParts[4]));
       }
     }
 
-    long[] mapIds = Longs.toArray(mapIdToBlocksInfo.keySet());
-    int[][] reduceIdArr = new int[mapIds.length][];
+    Set<Number> primaryIds = primaryIdToBlocksInfo.keySet();
+    // In case of FetchShuffleBlocks, secondaryIds are reduceIds. For FetchShuffleBlockChunks,
+    // secondaryIds are chunkIds.
+    int[][] secondaryIdsArray = new int[primaryIds.size()][];
     int blockIdIndex = 0;
-    for (int i = 0; i < mapIds.length; i++) {
-      BlocksInfo blocksInfoByMapId = mapIdToBlocksInfo.get(mapIds[i]);
-      reduceIdArr[i] = Ints.toArray(blocksInfoByMapId.reduceIds);
+    int secIndex = 0;
+    for (Number primaryId : primaryIds) {
+      BlocksInfo blocksInfoByPrimaryId = primaryIdToBlocksInfo.get(primaryId);
+      secondaryIdsArray[secIndex++] = Ints.toArray(blocksInfoByPrimaryId.ids);
 
-      // The `blockIds`'s order must be same with the read order specified in in FetchShuffleBlocks
-      // because the shuffle data's return order should match the `blockIds`'s order to ensure
-      // blockId and data match.
-      for (int j = 0; j < blocksInfoByMapId.blockIds.size(); j++) {
-        this.blockIds[blockIdIndex++] = blocksInfoByMapId.blockIds.get(j);
+      // The `blockIds`'s order must be same with the read order specified in FetchShuffleBlocks/
+      // FetchShuffleBlockChunks because the shuffle data's return order should match the
+      // `blockIds`'s order to ensure blockId and data match.
+      for (int j = 0; j < blocksInfoByPrimaryId.blockIds.size(); j++) {
+        this.blockIds[blockIdIndex++] = blocksInfoByPrimaryId.blockIds.get(j);

Review comment:
```suggestion
      for (String blockId : blocksInfoByPrimaryId.blockIds) {
        this.blockIds[blockIdIndex++] = blockId;
      }
```
