otterc commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r649514612
##########
File path:
core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -436,24 +485,48 @@ final class ShuffleBlockFetcherIterator(
val iterator = blockInfos.iterator
var curRequestSize = 0L
var curBlocks = Seq.empty[FetchBlockInfo]
-
while (iterator.hasNext) {
val (blockId, size, mapIndex) = iterator.next()
- assertPositiveBlockSize(blockId, size)
curBlocks = curBlocks ++ Seq(FetchBlockInfo(blockId, size, mapIndex))
curRequestSize += size
-      // For batch fetch, the actual block in flight should count for merged block.
-      val mayExceedsMaxBlocks = !doBatchFetch && curBlocks.size >= maxBlocksInFlightPerAddress
- if (curRequestSize >= targetRemoteRequestSize || mayExceedsMaxBlocks) {
- curBlocks = createFetchRequests(curBlocks, address, isLast = false,
- collectedRemoteRequests)
- curRequestSize = curBlocks.map(_.size).sum
+ blockId match {
+        // Either all blocks are merged blocks, merged block chunks, or original non-merged blocks.
+        // Based on these types, we decide to do batch fetch and create FetchRequests with
+        // forMergedMetas set.
+ case ShuffleBlockChunkId(_, _, _) =>
+ if (curRequestSize >= targetRemoteRequestSize ||
+ curBlocks.size >= maxBlocksInFlightPerAddress) {
+ curBlocks = createFetchRequests(curBlocks, address, isLast = false,
+ collectedRemoteRequests, enableBatchFetch = false)
+ curRequestSize = curBlocks.map(_.size).sum
+ }
+ case ShuffleBlockId(_, SHUFFLE_PUSH_MAP_ID, _) =>
+ if (curBlocks.size >= maxBlocksInFlightPerAddress) {
+ curBlocks = createFetchRequests(curBlocks, address, isLast = false,
+              collectedRemoteRequests, enableBatchFetch = false, forMergedMetas = true)
+ }
+ case _ =>
+          // For batch fetch, the actual block in flight should count for merged block.
+          val mayExceedsMaxBlocks = !doBatchFetch && curBlocks.size >= maxBlocksInFlightPerAddress
+          if (curRequestSize >= targetRemoteRequestSize || mayExceedsMaxBlocks) {
+ curBlocks = createFetchRequests(curBlocks, address, isLast = false,
+ collectedRemoteRequests, enableBatchFetch = doBatchFetch)
+ curRequestSize = curBlocks.map(_.size).sum
+ }
}
}
// Add in the final request
if (curBlocks.nonEmpty) {
+ val (enableBatchFetch, areMergedBlocks) = {
+ curBlocks.head.blockId match {
+ case ShuffleBlockChunkId(_, _, _) => (false, false)
+ case ShuffleBlockId(_, SHUFFLE_PUSH_MAP_ID, _) => (false, true)
+ case _ => (doBatchFetch, false)
+ }
+ }
curBlocks = createFetchRequests(curBlocks, address, isLast = true,
- collectedRemoteRequests)
+ collectedRemoteRequests, enableBatchFetch = enableBatchFetch,
+ forMergedMetas = areMergedBlocks)
curRequestSize = curBlocks.map(_.size).sum
Review comment:
       We do want the sum of the sizes of all the blocks in `curBlocks`, so I
think the `sum` is needed.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]