mridulm commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r647851101
##########
File path:
core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -347,20 +361,48 @@ final class ShuffleBlockFetcherIterator(
}
}
- private[this] def partitionBlocksByFetchMode(): ArrayBuffer[FetchRequest] = {
+ /**
+ * This is called from initialize and also from the fallback which is
triggered from
+ * [[PushBasedFetchHelper]].
+ */
+ private[this] def partitionBlocksByFetchMode(
+ blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])],
+ localBlocks: scala.collection.mutable.LinkedHashSet[(BlockId, Int)],
Review comment:
nit: Use either the qualified `mutable.LinkedHash*` form or import the class
and use it directly — this PR currently mixes both forms; please pick one and
use it consistently.
##########
File path:
core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala
##########
@@ -0,0 +1,289 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import java.util.concurrent.TimeUnit
+
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+import scala.util.{Failure, Success}
+
+import org.roaringbitmap.RoaringBitmap
+
+import org.apache.spark.MapOutputTracker
+import org.apache.spark.MapOutputTracker.SHUFFLE_PUSH_MAP_ID
+import org.apache.spark.internal.Logging
+import org.apache.spark.network.buffer.ManagedBuffer
+import org.apache.spark.network.shuffle.{BlockStoreClient, MergedBlockMeta,
MergedBlocksMetaListener}
+import org.apache.spark.storage.BlockManagerId.SHUFFLE_MERGER_IDENTIFIER
+import org.apache.spark.storage.ShuffleBlockFetcherIterator._
+
+/**
+ * Helper class for [[ShuffleBlockFetcherIterator]] that encapsulates all the
push-based
+ * functionality to fetch merged block meta and merged shuffle block chunks.
+ */
+private class PushBasedFetchHelper(
+ private val iterator: ShuffleBlockFetcherIterator,
+ private val shuffleClient: BlockStoreClient,
+ private val blockManager: BlockManager,
+ private val mapOutputTracker: MapOutputTracker) extends Logging {
+
+ private[this] val startTimeNs = System.nanoTime()
+
+ private[this] val localShuffleMergerBlockMgrId = BlockManagerId(
+ SHUFFLE_MERGER_IDENTIFIER, blockManager.blockManagerId.host,
+ blockManager.blockManagerId.port, blockManager.blockManagerId.topologyInfo)
+
+ /** A map for storing merged block shuffle chunk bitmap */
+ private[this] val chunksMetaMap = new mutable.HashMap[ShuffleBlockChunkId,
RoaringBitmap]()
+
+ /**
+ * Returns true if the address is for a push-merged block.
+ */
+ def isMergedShuffleBlockAddress(address: BlockManagerId): Boolean = {
+ SHUFFLE_MERGER_IDENTIFIER.equals(address.executorId)
+ }
+
+ /**
+ * Returns true if the address is not of executor local or merged local
block. false otherwise.
+ */
+ def isNotExecutorOrMergedLocal(address: BlockManagerId): Boolean = {
+ (isMergedShuffleBlockAddress(address) && address.host !=
blockManager.blockManagerId.host) ||
+ (!isMergedShuffleBlockAddress(address) && address !=
blockManager.blockManagerId)
+ }
+
+ /**
+ * Returns true if the address if of merged local block. false otherwise.
+ */
+ def isMergedLocal(address: BlockManagerId): Boolean = {
+ isMergedShuffleBlockAddress(address) && address.host ==
blockManager.blockManagerId.host
+ }
+
+ def getNumberOfBlocksInChunk(blockId : ShuffleBlockChunkId): Int = {
+ chunksMetaMap(blockId).getCardinality
+ }
+
+ def removeChunk(blockId: ShuffleBlockChunkId): Unit = {
+ chunksMetaMap.remove(blockId)
+ }
+
+ def createChunkBlockInfosFromMetaResponse(
+ shuffleId: Int,
+ reduceId: Int,
+ blockSize: Long,
+ numChunks: Int,
+ bitmaps: Array[RoaringBitmap]): ArrayBuffer[(BlockId, Long, Int)] = {
+ val approxChunkSize = blockSize / numChunks
+ val blocksToFetch = new ArrayBuffer[(BlockId, Long, Int)]()
+ for (i <- 0 until numChunks) {
+ val blockChunkId = ShuffleBlockChunkId(shuffleId, reduceId, i)
+ chunksMetaMap.put(blockChunkId, bitmaps(i))
+ logDebug(s"adding block chunk $blockChunkId of size $approxChunkSize")
+ blocksToFetch += ((blockChunkId, approxChunkSize, SHUFFLE_PUSH_MAP_ID))
+ }
+ blocksToFetch
+ }
+
+ def sendFetchMergedStatusRequest(req: FetchRequest): Unit = {
+ val sizeMap = req.blocks.map {
+ case FetchBlockInfo(blockId, size, _) =>
+ val shuffleBlockId = blockId.asInstanceOf[ShuffleBlockId]
+ ((shuffleBlockId.shuffleId, shuffleBlockId.reduceId), size)}.toMap
Review comment:
nit: Move `}.toMap` to next line
##########
File path:
core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -347,77 +355,118 @@ final class ShuffleBlockFetcherIterator(
}
}
- private[this] def partitionBlocksByFetchMode(): ArrayBuffer[FetchRequest] = {
+ /**
+ * This is called from initialize and also from the fallback which is
triggered from
+ * [[PushBasedFetchHelper]].
+ */
+ private[this] def partitionBlocksByFetchMode(
+ blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])],
+ localBlocks: mutable.LinkedHashSet[(BlockId, Int)],
+ hostLocalBlocksByExecutor: mutable.LinkedHashMap[BlockManagerId,
Seq[(BlockId, Long, Int)]],
+ mergedLocalBlocks: mutable.LinkedHashSet[BlockId]):
ArrayBuffer[FetchRequest] = {
logDebug(s"maxBytesInFlight: $maxBytesInFlight, targetRemoteRequestSize: "
+ s"$targetRemoteRequestSize, maxBlocksInFlightPerAddress:
$maxBlocksInFlightPerAddress")
- // Partition to local, host-local and remote blocks. Remote blocks are
further split into
- // FetchRequests of size at most maxBytesInFlight in order to limit the
amount of data in flight
+ // Partition to local, host-local, merged-local, remote (includes
merged-remote) blocks.
+ // Remote blocks are further split into FetchRequests of size at most
maxBytesInFlight in order
+ // to limit the amount of data in flight
val collectedRemoteRequests = new ArrayBuffer[FetchRequest]
+ val hostLocalBlocksCurrentIteration = mutable.LinkedHashSet[(BlockId,
Int)]()
var localBlockBytes = 0L
var hostLocalBlockBytes = 0L
+ var mergedLocalBlockBytes = 0L
var remoteBlockBytes = 0L
+ val prevNumBlocksToFetch = numBlocksToFetch
val fallback = FallbackStorage.FALLBACK_BLOCK_MANAGER_ID.executorId
for ((address, blockInfos) <- blocksByAddress) {
- if (Seq(blockManager.blockManagerId.executorId,
fallback).contains(address.executorId)) {
- checkBlockSizes(blockInfos)
+ checkBlockSizes(blockInfos)
+ if (pushBasedFetchHelper.isMergedShuffleBlockAddress(address)) {
+ // These are push-based merged blocks or chunks of these merged blocks.
+ if (address.host == blockManager.blockManagerId.host) {
+ val pushMergedBlockInfos = blockInfos.map(
+ info => FetchBlockInfo(info._1, info._2, info._3))
+ numBlocksToFetch += pushMergedBlockInfos.size
+ mergedLocalBlocks ++= pushMergedBlockInfos.map(info => info.blockId)
+ val size = pushMergedBlockInfos.map(_.size).sum
+ logInfo(s"Got ${pushMergedBlockInfos.size} local merged blocks " +
+ s"of size $size")
+ mergedLocalBlockBytes += size
+ } else {
+ remoteBlockBytes += blockInfos.map(_._2).sum
+ collectFetchRequests(address, blockInfos, collectedRemoteRequests)
+ }
+ } else if (
+ Seq(blockManager.blockManagerId.executorId,
fallback).contains(address.executorId)) {
val mergedBlockInfos = mergeContinuousShuffleBlockIdsIfNeeded(
blockInfos.map(info => FetchBlockInfo(info._1, info._2, info._3)),
doBatchFetch)
numBlocksToFetch += mergedBlockInfos.size
localBlocks ++= mergedBlockInfos.map(info => (info.blockId,
info.mapIndex))
localBlockBytes += mergedBlockInfos.map(_.size).sum
} else if (blockManager.hostLocalDirManager.isDefined &&
address.host == blockManager.blockManagerId.host) {
- checkBlockSizes(blockInfos)
val mergedBlockInfos = mergeContinuousShuffleBlockIdsIfNeeded(
blockInfos.map(info => FetchBlockInfo(info._1, info._2, info._3)),
doBatchFetch)
numBlocksToFetch += mergedBlockInfos.size
val blocksForAddress =
mergedBlockInfos.map(info => (info.blockId, info.size,
info.mapIndex))
hostLocalBlocksByExecutor += address -> blocksForAddress
- hostLocalBlocks ++= blocksForAddress.map(info => (info._1, info._3))
+ hostLocalBlocksCurrentIteration ++= blocksForAddress.map(info =>
(info._1, info._3))
hostLocalBlockBytes += mergedBlockInfos.map(_.size).sum
} else {
remoteBlockBytes += blockInfos.map(_._2).sum
collectFetchRequests(address, blockInfos, collectedRemoteRequests)
}
}
val numRemoteBlocks = collectedRemoteRequests.map(_.blocks.size).sum
- val totalBytes = localBlockBytes + remoteBlockBytes + hostLocalBlockBytes
- assert(numBlocksToFetch == localBlocks.size + hostLocalBlocks.size +
numRemoteBlocks,
- s"The number of non-empty blocks $numBlocksToFetch doesn't equal to the
number of local " +
- s"blocks ${localBlocks.size} + the number of host-local blocks
${hostLocalBlocks.size} " +
- s"+ the number of remote blocks ${numRemoteBlocks}.")
- logInfo(s"Getting $numBlocksToFetch (${Utils.bytesToString(totalBytes)})
non-empty blocks " +
- s"including ${localBlocks.size}
(${Utils.bytesToString(localBlockBytes)}) local and " +
- s"${hostLocalBlocks.size} (${Utils.bytesToString(hostLocalBlockBytes)})
" +
- s"host-local and $numRemoteBlocks
(${Utils.bytesToString(remoteBlockBytes)}) remote blocks")
+ val totalBytes = localBlockBytes + remoteBlockBytes + hostLocalBlockBytes +
+ mergedLocalBlockBytes
+ val blocksToFetchCurrentIteration = numBlocksToFetch - prevNumBlocksToFetch
+ assert(blocksToFetchCurrentIteration == localBlocks.size +
+ hostLocalBlocksCurrentIteration.size + numRemoteBlocks +
mergedLocalBlocks.size,
+ s"The number of non-empty blocks $blocksToFetchCurrentIteration doesn't
equal to " +
+ s"the number of local blocks ${localBlocks.size} + " +
+ s"the number of host-local blocks
${hostLocalBlocksCurrentIteration.size} " +
+ s"the number of merged-local blocks ${mergedLocalBlocks.size} " +
+ s"+ the number of remote blocks ${numRemoteBlocks} ")
+ logInfo(s"Getting $blocksToFetchCurrentIteration " +
+ s"(${Utils.bytesToString(totalBytes)}) non-empty blocks including " +
+ s"${localBlocks.size} (${Utils.bytesToString(localBlockBytes)}) local
and " +
+ s"${hostLocalBlocksCurrentIteration.size}
(${Utils.bytesToString(hostLocalBlockBytes)}) " +
+ s"host-local and ${mergedLocalBlocks.size}
(${Utils.bytesToString(mergedLocalBlockBytes)}) " +
+ s"local merged and $numRemoteBlocks
(${Utils.bytesToString(remoteBlockBytes)}) " +
+ s"remote blocks")
+ if (hostLocalBlocksCurrentIteration.nonEmpty) {
Review comment:
super nit: remove the `nonEmpty` check.
##########
File path:
core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -347,77 +355,118 @@ final class ShuffleBlockFetcherIterator(
}
}
- private[this] def partitionBlocksByFetchMode(): ArrayBuffer[FetchRequest] = {
+ /**
+ * This is called from initialize and also from the fallback which is
triggered from
+ * [[PushBasedFetchHelper]].
+ */
+ private[this] def partitionBlocksByFetchMode(
+ blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])],
+ localBlocks: mutable.LinkedHashSet[(BlockId, Int)],
+ hostLocalBlocksByExecutor: mutable.LinkedHashMap[BlockManagerId,
Seq[(BlockId, Long, Int)]],
+ mergedLocalBlocks: mutable.LinkedHashSet[BlockId]):
ArrayBuffer[FetchRequest] = {
logDebug(s"maxBytesInFlight: $maxBytesInFlight, targetRemoteRequestSize: "
+ s"$targetRemoteRequestSize, maxBlocksInFlightPerAddress:
$maxBlocksInFlightPerAddress")
- // Partition to local, host-local and remote blocks. Remote blocks are
further split into
- // FetchRequests of size at most maxBytesInFlight in order to limit the
amount of data in flight
+ // Partition to local, host-local, merged-local, remote (includes
merged-remote) blocks.
+ // Remote blocks are further split into FetchRequests of size at most
maxBytesInFlight in order
+ // to limit the amount of data in flight
val collectedRemoteRequests = new ArrayBuffer[FetchRequest]
+ val hostLocalBlocksCurrentIteration = mutable.LinkedHashSet[(BlockId,
Int)]()
var localBlockBytes = 0L
var hostLocalBlockBytes = 0L
+ var mergedLocalBlockBytes = 0L
var remoteBlockBytes = 0L
+ val prevNumBlocksToFetch = numBlocksToFetch
val fallback = FallbackStorage.FALLBACK_BLOCK_MANAGER_ID.executorId
for ((address, blockInfos) <- blocksByAddress) {
- if (Seq(blockManager.blockManagerId.executorId,
fallback).contains(address.executorId)) {
- checkBlockSizes(blockInfos)
+ checkBlockSizes(blockInfos)
+ if (pushBasedFetchHelper.isMergedShuffleBlockAddress(address)) {
+ // These are push-based merged blocks or chunks of these merged blocks.
+ if (address.host == blockManager.blockManagerId.host) {
+ val pushMergedBlockInfos = blockInfos.map(
+ info => FetchBlockInfo(info._1, info._2, info._3))
+ numBlocksToFetch += pushMergedBlockInfos.size
+ mergedLocalBlocks ++= pushMergedBlockInfos.map(info => info.blockId)
+ val size = pushMergedBlockInfos.map(_.size).sum
+ logInfo(s"Got ${pushMergedBlockInfos.size} local merged blocks " +
+ s"of size $size")
+ mergedLocalBlockBytes += size
+ } else {
+ remoteBlockBytes += blockInfos.map(_._2).sum
+ collectFetchRequests(address, blockInfos, collectedRemoteRequests)
+ }
+ } else if (
+ Seq(blockManager.blockManagerId.executorId,
fallback).contains(address.executorId)) {
val mergedBlockInfos = mergeContinuousShuffleBlockIdsIfNeeded(
blockInfos.map(info => FetchBlockInfo(info._1, info._2, info._3)),
doBatchFetch)
numBlocksToFetch += mergedBlockInfos.size
localBlocks ++= mergedBlockInfos.map(info => (info.blockId,
info.mapIndex))
localBlockBytes += mergedBlockInfos.map(_.size).sum
} else if (blockManager.hostLocalDirManager.isDefined &&
address.host == blockManager.blockManagerId.host) {
- checkBlockSizes(blockInfos)
val mergedBlockInfos = mergeContinuousShuffleBlockIdsIfNeeded(
blockInfos.map(info => FetchBlockInfo(info._1, info._2, info._3)),
doBatchFetch)
numBlocksToFetch += mergedBlockInfos.size
val blocksForAddress =
mergedBlockInfos.map(info => (info.blockId, info.size,
info.mapIndex))
hostLocalBlocksByExecutor += address -> blocksForAddress
- hostLocalBlocks ++= blocksForAddress.map(info => (info._1, info._3))
+ hostLocalBlocksCurrentIteration ++= blocksForAddress.map(info =>
(info._1, info._3))
hostLocalBlockBytes += mergedBlockInfos.map(_.size).sum
} else {
remoteBlockBytes += blockInfos.map(_._2).sum
collectFetchRequests(address, blockInfos, collectedRemoteRequests)
}
}
val numRemoteBlocks = collectedRemoteRequests.map(_.blocks.size).sum
- val totalBytes = localBlockBytes + remoteBlockBytes + hostLocalBlockBytes
- assert(numBlocksToFetch == localBlocks.size + hostLocalBlocks.size +
numRemoteBlocks,
- s"The number of non-empty blocks $numBlocksToFetch doesn't equal to the
number of local " +
- s"blocks ${localBlocks.size} + the number of host-local blocks
${hostLocalBlocks.size} " +
- s"+ the number of remote blocks ${numRemoteBlocks}.")
- logInfo(s"Getting $numBlocksToFetch (${Utils.bytesToString(totalBytes)})
non-empty blocks " +
- s"including ${localBlocks.size}
(${Utils.bytesToString(localBlockBytes)}) local and " +
- s"${hostLocalBlocks.size} (${Utils.bytesToString(hostLocalBlockBytes)})
" +
- s"host-local and $numRemoteBlocks
(${Utils.bytesToString(remoteBlockBytes)}) remote blocks")
+ val totalBytes = localBlockBytes + remoteBlockBytes + hostLocalBlockBytes +
+ mergedLocalBlockBytes
+ val blocksToFetchCurrentIteration = numBlocksToFetch - prevNumBlocksToFetch
+ assert(blocksToFetchCurrentIteration == localBlocks.size +
+ hostLocalBlocksCurrentIteration.size + numRemoteBlocks +
mergedLocalBlocks.size,
+ s"The number of non-empty blocks $blocksToFetchCurrentIteration doesn't
equal to " +
+ s"the number of local blocks ${localBlocks.size} + " +
+ s"the number of host-local blocks
${hostLocalBlocksCurrentIteration.size} " +
+ s"the number of merged-local blocks ${mergedLocalBlocks.size} " +
+ s"+ the number of remote blocks ${numRemoteBlocks} ")
+ logInfo(s"Getting $blocksToFetchCurrentIteration " +
+ s"(${Utils.bytesToString(totalBytes)}) non-empty blocks including " +
+ s"${localBlocks.size} (${Utils.bytesToString(localBlockBytes)}) local
and " +
+ s"${hostLocalBlocksCurrentIteration.size}
(${Utils.bytesToString(hostLocalBlockBytes)}) " +
+ s"host-local and ${mergedLocalBlocks.size}
(${Utils.bytesToString(mergedLocalBlockBytes)}) " +
+ s"local merged and $numRemoteBlocks
(${Utils.bytesToString(remoteBlockBytes)}) " +
+ s"remote blocks")
+ if (hostLocalBlocksCurrentIteration.nonEmpty) {
+ this.hostLocalBlocks ++= hostLocalBlocksCurrentIteration
+ }
collectedRemoteRequests
}
private def createFetchRequest(
blocks: Seq[FetchBlockInfo],
- address: BlockManagerId): FetchRequest = {
+ address: BlockManagerId,
+ forMergedMetas: Boolean = false): FetchRequest = {
Review comment:
Remove the default value for `forMergedMetas`? Requiring callers to pass it
explicitly makes each call site's intent clear.
##########
File path:
core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala
##########
@@ -0,0 +1,289 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import java.util.concurrent.TimeUnit
+
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+import scala.util.{Failure, Success}
+
+import org.roaringbitmap.RoaringBitmap
+
+import org.apache.spark.MapOutputTracker
+import org.apache.spark.MapOutputTracker.SHUFFLE_PUSH_MAP_ID
+import org.apache.spark.internal.Logging
+import org.apache.spark.network.buffer.ManagedBuffer
+import org.apache.spark.network.shuffle.{BlockStoreClient, MergedBlockMeta,
MergedBlocksMetaListener}
+import org.apache.spark.storage.BlockManagerId.SHUFFLE_MERGER_IDENTIFIER
+import org.apache.spark.storage.ShuffleBlockFetcherIterator._
+
+/**
+ * Helper class for [[ShuffleBlockFetcherIterator]] that encapsulates all the
push-based
+ * functionality to fetch merged block meta and merged shuffle block chunks.
+ */
+private class PushBasedFetchHelper(
+ private val iterator: ShuffleBlockFetcherIterator,
+ private val shuffleClient: BlockStoreClient,
+ private val blockManager: BlockManager,
+ private val mapOutputTracker: MapOutputTracker) extends Logging {
+
+ private[this] val startTimeNs = System.nanoTime()
+
+ private[this] val localShuffleMergerBlockMgrId = BlockManagerId(
+ SHUFFLE_MERGER_IDENTIFIER, blockManager.blockManagerId.host,
+ blockManager.blockManagerId.port, blockManager.blockManagerId.topologyInfo)
+
+ /** A map for storing merged block shuffle chunk bitmap */
+ private[this] val chunksMetaMap = new mutable.HashMap[ShuffleBlockChunkId,
RoaringBitmap]()
+
+ /**
+ * Returns true if the address is for a push-merged block.
+ */
+ def isMergedShuffleBlockAddress(address: BlockManagerId): Boolean = {
+ SHUFFLE_MERGER_IDENTIFIER.equals(address.executorId)
+ }
+
+ /**
+ * Returns true if the address is not of executor local or merged local
block. false otherwise.
+ */
+ def isNotExecutorOrMergedLocal(address: BlockManagerId): Boolean = {
+ (isMergedShuffleBlockAddress(address) && address.host !=
blockManager.blockManagerId.host) ||
+ (!isMergedShuffleBlockAddress(address) && address !=
blockManager.blockManagerId)
+ }
+
+ /**
+ * Returns true if the address if of merged local block. false otherwise.
+ */
+ def isMergedLocal(address: BlockManagerId): Boolean = {
+ isMergedShuffleBlockAddress(address) && address.host ==
blockManager.blockManagerId.host
+ }
+
+ def getNumberOfBlocksInChunk(blockId : ShuffleBlockChunkId): Int = {
+ chunksMetaMap(blockId).getCardinality
+ }
+
+ def removeChunk(blockId: ShuffleBlockChunkId): Unit = {
+ chunksMetaMap.remove(blockId)
+ }
+
+ def createChunkBlockInfosFromMetaResponse(
+ shuffleId: Int,
+ reduceId: Int,
+ blockSize: Long,
+ numChunks: Int,
+ bitmaps: Array[RoaringBitmap]): ArrayBuffer[(BlockId, Long, Int)] = {
+ val approxChunkSize = blockSize / numChunks
+ val blocksToFetch = new ArrayBuffer[(BlockId, Long, Int)]()
+ for (i <- 0 until numChunks) {
+ val blockChunkId = ShuffleBlockChunkId(shuffleId, reduceId, i)
+ chunksMetaMap.put(blockChunkId, bitmaps(i))
+ logDebug(s"adding block chunk $blockChunkId of size $approxChunkSize")
+ blocksToFetch += ((blockChunkId, approxChunkSize, SHUFFLE_PUSH_MAP_ID))
+ }
+ blocksToFetch
+ }
+
+ def sendFetchMergedStatusRequest(req: FetchRequest): Unit = {
+ val sizeMap = req.blocks.map {
+ case FetchBlockInfo(blockId, size, _) =>
+ val shuffleBlockId = blockId.asInstanceOf[ShuffleBlockId]
+ ((shuffleBlockId.shuffleId, shuffleBlockId.reduceId), size)}.toMap
+ val address = req.address
+ val mergedBlocksMetaListener = new MergedBlocksMetaListener {
+ override def onSuccess(shuffleId: Int, reduceId: Int, meta:
MergedBlockMeta): Unit = {
+ logInfo(s"Received the meta of merged block for ($shuffleId,
$reduceId) " +
+ s"from ${req.address.host}:${req.address.port}")
+ try {
+ iterator.addToResultsQueue(MergedBlocksMetaFetchResult(shuffleId,
reduceId,
+ sizeMap((shuffleId, reduceId)), meta.getNumChunks,
meta.readChunkBitmaps(), address))
+ } catch {
+ case exception: Throwable =>
Review comment:
Why catch `Throwable`? Catching fatal errors (e.g. `OutOfMemoryError`) is
generally unsafe — prefer `scala.util.control.NonFatal` here.
##########
File path:
core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala
##########
@@ -0,0 +1,289 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import java.util.concurrent.TimeUnit
+
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+import scala.util.{Failure, Success}
+
+import org.roaringbitmap.RoaringBitmap
+
+import org.apache.spark.MapOutputTracker
+import org.apache.spark.MapOutputTracker.SHUFFLE_PUSH_MAP_ID
+import org.apache.spark.internal.Logging
+import org.apache.spark.network.buffer.ManagedBuffer
+import org.apache.spark.network.shuffle.{BlockStoreClient, MergedBlockMeta,
MergedBlocksMetaListener}
+import org.apache.spark.storage.BlockManagerId.SHUFFLE_MERGER_IDENTIFIER
+import org.apache.spark.storage.ShuffleBlockFetcherIterator._
+
+/**
+ * Helper class for [[ShuffleBlockFetcherIterator]] that encapsulates all the
push-based
+ * functionality to fetch merged block meta and merged shuffle block chunks.
+ */
+private class PushBasedFetchHelper(
+ private val iterator: ShuffleBlockFetcherIterator,
+ private val shuffleClient: BlockStoreClient,
+ private val blockManager: BlockManager,
+ private val mapOutputTracker: MapOutputTracker) extends Logging {
+
+ private[this] val startTimeNs = System.nanoTime()
+
+ private[this] val localShuffleMergerBlockMgrId = BlockManagerId(
+ SHUFFLE_MERGER_IDENTIFIER, blockManager.blockManagerId.host,
+ blockManager.blockManagerId.port, blockManager.blockManagerId.topologyInfo)
+
+ /** A map for storing merged block shuffle chunk bitmap */
+ private[this] val chunksMetaMap = new mutable.HashMap[ShuffleBlockChunkId,
RoaringBitmap]()
Review comment:
There can be concurrent modifications to this map; should it be made
thread-safe (e.g. by synchronizing access or using a concurrent map)?
##########
File path:
core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -347,20 +361,48 @@ final class ShuffleBlockFetcherIterator(
}
}
- private[this] def partitionBlocksByFetchMode(): ArrayBuffer[FetchRequest] = {
+ /**
+ * This is called from initialize and also from the fallback which is
triggered from
+ * [[PushBasedFetchHelper]].
+ */
+ private[this] def partitionBlocksByFetchMode(
+ blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])],
+ localBlocks: scala.collection.mutable.LinkedHashSet[(BlockId, Int)],
+ hostLocalBlocksByExecutor: mutable.LinkedHashMap[BlockManagerId,
Seq[(BlockId, Long, Int)]],
+ mergedLocalBlocks: mutable.LinkedHashSet[BlockId]):
ArrayBuffer[FetchRequest] = {
logDebug(s"maxBytesInFlight: $maxBytesInFlight, targetRemoteRequestSize: "
+ s"$targetRemoteRequestSize, maxBlocksInFlightPerAddress:
$maxBlocksInFlightPerAddress")
- // Partition to local, host-local and remote blocks. Remote blocks are
further split into
- // FetchRequests of size at most maxBytesInFlight in order to limit the
amount of data in flight
+ // Partition to local, host-local, merged-local, remote (includes
merged-remote) blocks.
+ // Remote blocks are further split into FetchRequests of size at most
maxBytesInFlight in order
+ // to limit the amount of data in flight
val collectedRemoteRequests = new ArrayBuffer[FetchRequest]
+ val hostLocalBlocksCurrentIteration = mutable.LinkedHashSet[(BlockId,
Int)]()
var localBlockBytes = 0L
var hostLocalBlockBytes = 0L
+ var mergedLocalBlockBytes = 0L
var remoteBlockBytes = 0L
+ val prevNumBlocksToFetch = numBlocksToFetch
val fallback = FallbackStorage.FALLBACK_BLOCK_MANAGER_ID.executorId
for ((address, blockInfos) <- blocksByAddress) {
- if (Seq(blockManager.blockManagerId.executorId,
fallback).contains(address.executorId)) {
+ if (pushBasedFetchHelper.isMergedShuffleBlockAddress(address)) {
+ // These are push-based merged blocks or chunks of these merged blocks.
+ if (address.host == blockManager.blockManagerId.host) {
+ checkBlockSizes(blockInfos)
+ val pushMergedBlockInfos = mergeContinuousShuffleBlockIdsIfNeeded(
+ blockInfos.map(info => FetchBlockInfo(info._1, info._2, info._3)),
doBatchFetch = false)
Review comment:
For merged blocks, why are we doing this? It is currently a no-op anyway, so
`pushMergedBlockInfos` can be removed entirely here.
##########
File path:
core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala
##########
@@ -0,0 +1,289 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import java.util.concurrent.TimeUnit
+
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+import scala.util.{Failure, Success}
+
+import org.roaringbitmap.RoaringBitmap
+
+import org.apache.spark.MapOutputTracker
+import org.apache.spark.MapOutputTracker.SHUFFLE_PUSH_MAP_ID
+import org.apache.spark.internal.Logging
+import org.apache.spark.network.buffer.ManagedBuffer
+import org.apache.spark.network.shuffle.{BlockStoreClient, MergedBlockMeta,
MergedBlocksMetaListener}
+import org.apache.spark.storage.BlockManagerId.SHUFFLE_MERGER_IDENTIFIER
+import org.apache.spark.storage.ShuffleBlockFetcherIterator._
+
+/**
+ * Helper class for [[ShuffleBlockFetcherIterator]] that encapsulates all the
push-based
+ * functionality to fetch merged block meta and merged shuffle block chunks.
+ */
+private class PushBasedFetchHelper(
+ private val iterator: ShuffleBlockFetcherIterator,
+ private val shuffleClient: BlockStoreClient,
+ private val blockManager: BlockManager,
+ private val mapOutputTracker: MapOutputTracker) extends Logging {
+
+ private[this] val startTimeNs = System.nanoTime()
+
+ private[this] val localShuffleMergerBlockMgrId = BlockManagerId(
+ SHUFFLE_MERGER_IDENTIFIER, blockManager.blockManagerId.host,
+ blockManager.blockManagerId.port, blockManager.blockManagerId.topologyInfo)
+
+ /** A map for storing merged block shuffle chunk bitmap */
+ private[this] val chunksMetaMap = new mutable.HashMap[ShuffleBlockChunkId,
RoaringBitmap]()
+
+ /**
+ * Returns true if the address is for a push-merged block.
+ */
+ def isMergedShuffleBlockAddress(address: BlockManagerId): Boolean = {
+ SHUFFLE_MERGER_IDENTIFIER.equals(address.executorId)
Review comment:
nit: use `==`
##########
File path:
core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala
##########
@@ -0,0 +1,289 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import java.util.concurrent.TimeUnit
+
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+import scala.util.{Failure, Success}
+
+import org.roaringbitmap.RoaringBitmap
+
+import org.apache.spark.MapOutputTracker
+import org.apache.spark.MapOutputTracker.SHUFFLE_PUSH_MAP_ID
+import org.apache.spark.internal.Logging
+import org.apache.spark.network.buffer.ManagedBuffer
+import org.apache.spark.network.shuffle.{BlockStoreClient, MergedBlockMeta,
MergedBlocksMetaListener}
+import org.apache.spark.storage.BlockManagerId.SHUFFLE_MERGER_IDENTIFIER
+import org.apache.spark.storage.ShuffleBlockFetcherIterator._
+
+/**
+ * Helper class for [[ShuffleBlockFetcherIterator]] that encapsulates all the
push-based
+ * functionality to fetch merged block meta and merged shuffle block chunks.
+ */
+private class PushBasedFetchHelper(
+ private val iterator: ShuffleBlockFetcherIterator,
+ private val shuffleClient: BlockStoreClient,
+ private val blockManager: BlockManager,
+ private val mapOutputTracker: MapOutputTracker) extends Logging {
+
+ private[this] val startTimeNs = System.nanoTime()
+
+ private[this] val localShuffleMergerBlockMgrId = BlockManagerId(
+ SHUFFLE_MERGER_IDENTIFIER, blockManager.blockManagerId.host,
+ blockManager.blockManagerId.port, blockManager.blockManagerId.topologyInfo)
+
+ /** A map for storing merged block shuffle chunk bitmap */
+ private[this] val chunksMetaMap = new mutable.HashMap[ShuffleBlockChunkId,
RoaringBitmap]()
+
+ /**
+ * Returns true if the address is for a push-merged block.
+ */
+ def isMergedShuffleBlockAddress(address: BlockManagerId): Boolean = {
+ SHUFFLE_MERGER_IDENTIFIER.equals(address.executorId)
+ }
+
+ /**
+ * Returns true if the address is not of executor local or merged local
block. false otherwise.
+ */
+ def isNotExecutorOrMergedLocal(address: BlockManagerId): Boolean = {
+ (isMergedShuffleBlockAddress(address) && address.host !=
blockManager.blockManagerId.host) ||
+ (!isMergedShuffleBlockAddress(address) && address !=
blockManager.blockManagerId)
+ }
+
+ /**
+ * Returns true if the address is of a merged local block, false otherwise.
+ */
+ def isMergedLocal(address: BlockManagerId): Boolean = {
+ isMergedShuffleBlockAddress(address) && address.host ==
blockManager.blockManagerId.host
+ }
+
+ def getNumberOfBlocksInChunk(blockId : ShuffleBlockChunkId): Int = {
+ chunksMetaMap(blockId).getCardinality
+ }
+
+ def removeChunk(blockId: ShuffleBlockChunkId): Unit = {
+ chunksMetaMap.remove(blockId)
+ }
+
+ def createChunkBlockInfosFromMetaResponse(
+ shuffleId: Int,
+ reduceId: Int,
+ blockSize: Long,
+ numChunks: Int,
+ bitmaps: Array[RoaringBitmap]): ArrayBuffer[(BlockId, Long, Int)] = {
+ val approxChunkSize = blockSize / numChunks
+ val blocksToFetch = new ArrayBuffer[(BlockId, Long, Int)]()
+ for (i <- 0 until numChunks) {
+ val blockChunkId = ShuffleBlockChunkId(shuffleId, reduceId, i)
+ chunksMetaMap.put(blockChunkId, bitmaps(i))
+ logDebug(s"adding block chunk $blockChunkId of size $approxChunkSize")
+ blocksToFetch += ((blockChunkId, approxChunkSize, SHUFFLE_PUSH_MAP_ID))
+ }
+ blocksToFetch
+ }
+
+ def sendFetchMergedStatusRequest(req: FetchRequest): Unit = {
+ val sizeMap = req.blocks.map {
+ case FetchBlockInfo(blockId, size, _) =>
+ val shuffleBlockId = blockId.asInstanceOf[ShuffleBlockId]
+ ((shuffleBlockId.shuffleId, shuffleBlockId.reduceId), size)}.toMap
+ val address = req.address
+ val mergedBlocksMetaListener = new MergedBlocksMetaListener {
+ override def onSuccess(shuffleId: Int, reduceId: Int, meta:
MergedBlockMeta): Unit = {
+ logInfo(s"Received the meta of merged block for ($shuffleId,
$reduceId) " +
+ s"from ${req.address.host}:${req.address.port}")
+ try {
+ iterator.addToResultsQueue(MergedBlocksMetaFetchResult(shuffleId,
reduceId,
+ sizeMap((shuffleId, reduceId)), meta.getNumChunks,
meta.readChunkBitmaps(), address))
+ } catch {
+ case exception: Throwable =>
+ logError(s"Failed to parse the meta of merged block for
($shuffleId, $reduceId) " +
+ s"from ${req.address.host}:${req.address.port}", exception)
+ iterator.addToResultsQueue(
+ MergedBlocksMetaFailedFetchResult(shuffleId, reduceId, address))
+ }
+ }
+
+ override def onFailure(shuffleId: Int, reduceId: Int, exception:
Throwable): Unit = {
+ logError(s"Failed to get the meta of merged block for ($shuffleId,
$reduceId) " +
+ s"from ${req.address.host}:${req.address.port}", exception)
+
iterator.addToResultsQueue(MergedBlocksMetaFailedFetchResult(shuffleId,
reduceId, address))
+ }
+ }
+ req.blocks.foreach { block =>
+ val shuffleBlockId = block.blockId.asInstanceOf[ShuffleBlockId]
+ shuffleClient.getMergedBlockMeta(address.host, address.port,
shuffleBlockId.shuffleId,
+ shuffleBlockId.reduceId, mergedBlocksMetaListener)
+ }
+ }
+
+ // Fetch all outstanding merged local blocks
+ def fetchAllMergedLocalBlocks(
+ mergedLocalBlocks: mutable.LinkedHashSet[BlockId]): Unit = {
+ if (mergedLocalBlocks.nonEmpty) {
+ blockManager.hostLocalDirManager.foreach(fetchMergedLocalBlocks(_,
mergedLocalBlocks))
+ }
+ }
+
+ /**
+ * Fetch the merged blocks dirs if they are not in the cache and eventually
fetch merged local
+ * blocks.
+ */
+ private def fetchMergedLocalBlocks(
+ hostLocalDirManager: HostLocalDirManager,
+ mergedLocalBlocks: mutable.LinkedHashSet[BlockId]): Unit = {
+ val cachedMergerDirs = hostLocalDirManager.getCachedHostLocalDirs.get(
+ SHUFFLE_MERGER_IDENTIFIER)
+ if (cachedMergerDirs.isDefined) {
+ logDebug(s"Fetching local merged blocks with cached executors dir: " +
+ s"${cachedMergerDirs.get.mkString(", ")}")
+ mergedLocalBlocks.foreach(blockId =>
+ fetchMergedLocalBlock(blockId, cachedMergerDirs.get,
localShuffleMergerBlockMgrId))
+ } else {
+ logDebug(s"Asynchronous fetching local merged blocks without cached
executors dir")
+ hostLocalDirManager.getHostLocalDirs(localShuffleMergerBlockMgrId.host,
+ localShuffleMergerBlockMgrId.port, Array(SHUFFLE_MERGER_IDENTIFIER)) {
+ case Success(dirs) =>
+ mergedLocalBlocks.takeWhile {
+ blockId =>
+ logDebug(s"Successfully fetched local dirs: " +
+ s"${dirs.get(SHUFFLE_MERGER_IDENTIFIER).mkString(", ")}")
+ fetchMergedLocalBlock(blockId, dirs(SHUFFLE_MERGER_IDENTIFIER),
+ localShuffleMergerBlockMgrId)
+ }
+ logDebug(s"Got local merged blocks (without cached executors' dir)
in " +
+ s"${TimeUnit.NANOSECONDS.toMillis(System.nanoTime() -
startTimeNs)} ms")
+ case Failure(throwable) =>
+ // If we see an exception with getting the local dirs for local
merged blocks,
+ // we fallback to fetch the original unmerged blocks. We do not
report block fetch
+ // failure.
+ logWarning(s"Error occurred while getting the local dirs for local
merged " +
+ s"blocks: ${mergedLocalBlocks.mkString(", ")}. Fetch the original
blocks instead",
+ throwable)
+ mergedLocalBlocks.foreach(
+ blockId => iterator.addToResultsQueue(
+ IgnoreFetchResult(blockId, localShuffleMergerBlockMgrId, 0,
isNetworkReqDone = false))
+ )
+ }
+ }
+ }
+
+ /**
+ * Fetch a single local merged block.
+ * @param blockId ShuffleBlockId to be fetched
+ * @param localDirs Local directories where the merged shuffle files are
stored
+ * @param blockManagerId BlockManagerId
+ * @return Boolean represents successful or failed fetch
+ */
+ private[this] def fetchMergedLocalBlock(
+ blockId: BlockId,
+ localDirs: Array[String],
+ blockManagerId: BlockManagerId): Boolean = {
+ try {
+ val shuffleBlockId = blockId.asInstanceOf[ShuffleBlockId]
+ val chunksMeta = blockManager.getMergedBlockMeta(shuffleBlockId,
localDirs)
+ .readChunkBitmaps()
+ // Fetch local merged shuffle block data as multiple chunks
+ val bufs: Seq[ManagedBuffer] =
blockManager.getMergedBlockData(shuffleBlockId, localDirs)
+ // Update total number of blocks to fetch, reflecting the multiple local
chunks
+ iterator.foundMoreBlocksToFetch(bufs.size - 1)
+ for (chunkId <- bufs.indices) {
+ val buf = bufs(chunkId)
+ buf.retain()
+ val shuffleChunkId = ShuffleBlockChunkId(shuffleBlockId.shuffleId,
+ shuffleBlockId.reduceId, chunkId)
+ iterator.addToResultsQueue(
+ SuccessFetchResult(shuffleChunkId, SHUFFLE_PUSH_MAP_ID,
blockManagerId, buf.size(), buf,
+ isNetworkReqDone = false))
+ chunksMetaMap.put(shuffleChunkId, chunksMeta(chunkId))
+ }
+ true
+ } catch {
+ case e: Exception =>
+ // If we see an exception with reading a local merged block, we
fallback to
+ // fetch the original unmerged blocks. We do not report block fetch
failure
+ // and will continue with the remaining local block read.
+ logWarning(s"Error occurred while fetching local merged block, " +
+ s"prepare to fetch the original blocks", e)
+ iterator.addToResultsQueue(
+ IgnoreFetchResult(blockId, blockManagerId, 0, isNetworkReqDone =
false))
+ false
+ }
+ }
+
+ /**
+ * Initiate fetching fallback blocks for a merged block (or a merged block
chunk) that's failed
+ * to fetch.
+ * It calls out to map output tracker to get the list of original blocks for
the
+ * given merged blocks, split them into remote and local blocks, and process
them
+ * accordingly.
+ * The fallback happens when:
+ * 1. There is an exception while creating shuffle block chunk from local
merged shuffle block.
+ * See fetchLocalBlock.
+ * 2. There is a failure when fetching remote shuffle block chunks.
+ * 3. There is a failure when processing SuccessFetchResult which is for a
shuffle chunk
+ * (local or remote).
+ *
+ * @return number of blocks processed
+ */
+ def initiateFallbackBlockFetchForMergedBlock(
+ blockId: BlockId,
+ address: BlockManagerId): Int = {
+ logWarning(s"Falling back to fetch the original unmerged blocks for merged
block $blockId")
+ // Increase the blocks processed since we will process another block in
the next iteration of
+ // the while loop in ShuffleBlockFetcherIterator.next().
+ var blocksProcessed = 1
+ val fallbackBlocksByAddr: Iterator[(BlockManagerId, Seq[(BlockId, Long,
Int)])] =
+ if (blockId.isShuffle) {
+ val shuffleBlockId = blockId.asInstanceOf[ShuffleBlockId]
+ mapOutputTracker.getMapSizesForMergeResult(
+ shuffleBlockId.shuffleId, shuffleBlockId.reduceId)
+ } else {
+ val shuffleChunkId = blockId.asInstanceOf[ShuffleBlockChunkId]
+ val chunkBitmap: RoaringBitmap =
chunksMetaMap.remove(shuffleChunkId).orNull
+ // When there is a failure to fetch a remote merged shuffle block
chunk, then we try to
+ // fallback not only for that particular remote shuffle block chunk
but also for all the
+ // pending block chunks that belong to the same host. The reason for
doing so is that it is
+ // very likely that the subsequent requests for merged block chunks
from this host will fail
+ // as well. Since, push-based shuffle is best effort and we try not to
increase the delay
+ // of the fetches, we immediately fallback for all the pending shuffle
chunks in the
+ // fetchRequests queue.
+ if (isNotExecutorOrMergedLocal(address)) {
+ // Fallback for all the pending fetch requests
+ val pendingShuffleChunks =
iterator.removePendingChunks(shuffleChunkId, address)
+ if (pendingShuffleChunks.nonEmpty) {
+ pendingShuffleChunks.foreach { pendingBlockId =>
+ logWarning(s"Falling back immediately for merged block
$pendingBlockId")
Review comment:
nit: `logInfo` here ?
##########
File path:
core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala
##########
@@ -0,0 +1,289 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import java.util.concurrent.TimeUnit
+
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+import scala.util.{Failure, Success}
+
+import org.roaringbitmap.RoaringBitmap
+
+import org.apache.spark.MapOutputTracker
+import org.apache.spark.MapOutputTracker.SHUFFLE_PUSH_MAP_ID
+import org.apache.spark.internal.Logging
+import org.apache.spark.network.buffer.ManagedBuffer
+import org.apache.spark.network.shuffle.{BlockStoreClient, MergedBlockMeta,
MergedBlocksMetaListener}
+import org.apache.spark.storage.BlockManagerId.SHUFFLE_MERGER_IDENTIFIER
+import org.apache.spark.storage.ShuffleBlockFetcherIterator._
+
+/**
+ * Helper class for [[ShuffleBlockFetcherIterator]] that encapsulates all the
push-based
+ * functionality to fetch merged block meta and merged shuffle block chunks.
+ */
+private class PushBasedFetchHelper(
+ private val iterator: ShuffleBlockFetcherIterator,
+ private val shuffleClient: BlockStoreClient,
+ private val blockManager: BlockManager,
+ private val mapOutputTracker: MapOutputTracker) extends Logging {
+
+ private[this] val startTimeNs = System.nanoTime()
+
+ private[this] val localShuffleMergerBlockMgrId = BlockManagerId(
+ SHUFFLE_MERGER_IDENTIFIER, blockManager.blockManagerId.host,
+ blockManager.blockManagerId.port, blockManager.blockManagerId.topologyInfo)
+
+ /** A map for storing merged block shuffle chunk bitmap */
+ private[this] val chunksMetaMap = new mutable.HashMap[ShuffleBlockChunkId,
RoaringBitmap]()
+
+ /**
+ * Returns true if the address is for a push-merged block.
+ */
+ def isMergedShuffleBlockAddress(address: BlockManagerId): Boolean = {
+ SHUFFLE_MERGER_IDENTIFIER.equals(address.executorId)
+ }
+
+ /**
+ * Returns true if the address is not of executor local or merged local
block. false otherwise.
+ */
+ def isNotExecutorOrMergedLocal(address: BlockManagerId): Boolean = {
+ (isMergedShuffleBlockAddress(address) && address.host !=
blockManager.blockManagerId.host) ||
+ (!isMergedShuffleBlockAddress(address) && address !=
blockManager.blockManagerId)
Review comment:
Does this (and the caller tree) support SPARK-27651?
##########
File path:
core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -347,77 +355,118 @@ final class ShuffleBlockFetcherIterator(
}
}
- private[this] def partitionBlocksByFetchMode(): ArrayBuffer[FetchRequest] = {
+ /**
+ * This is called from initialize and also from the fallback which is
triggered from
+ * [[PushBasedFetchHelper]].
+ */
+ private[this] def partitionBlocksByFetchMode(
+ blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])],
+ localBlocks: mutable.LinkedHashSet[(BlockId, Int)],
+ hostLocalBlocksByExecutor: mutable.LinkedHashMap[BlockManagerId,
Seq[(BlockId, Long, Int)]],
+ mergedLocalBlocks: mutable.LinkedHashSet[BlockId]):
ArrayBuffer[FetchRequest] = {
logDebug(s"maxBytesInFlight: $maxBytesInFlight, targetRemoteRequestSize: "
+ s"$targetRemoteRequestSize, maxBlocksInFlightPerAddress:
$maxBlocksInFlightPerAddress")
- // Partition to local, host-local and remote blocks. Remote blocks are
further split into
- // FetchRequests of size at most maxBytesInFlight in order to limit the
amount of data in flight
+ // Partition to local, host-local, merged-local, remote (includes
merged-remote) blocks.
+ // Remote blocks are further split into FetchRequests of size at most
maxBytesInFlight in order
+ // to limit the amount of data in flight
val collectedRemoteRequests = new ArrayBuffer[FetchRequest]
+ val hostLocalBlocksCurrentIteration = mutable.LinkedHashSet[(BlockId,
Int)]()
var localBlockBytes = 0L
var hostLocalBlockBytes = 0L
+ var mergedLocalBlockBytes = 0L
var remoteBlockBytes = 0L
+ val prevNumBlocksToFetch = numBlocksToFetch
val fallback = FallbackStorage.FALLBACK_BLOCK_MANAGER_ID.executorId
for ((address, blockInfos) <- blocksByAddress) {
- if (Seq(blockManager.blockManagerId.executorId,
fallback).contains(address.executorId)) {
- checkBlockSizes(blockInfos)
Review comment:
Now `checkBlockSizes` is being done for all the cases ... while earlier,
it was not done for the last `else`.
Did you look into whether this is ok ?
+CC @attilapiros who did this change initially.
##########
File path:
core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala
##########
@@ -0,0 +1,289 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import java.util.concurrent.TimeUnit
+
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+import scala.util.{Failure, Success}
+
+import org.roaringbitmap.RoaringBitmap
+
+import org.apache.spark.MapOutputTracker
+import org.apache.spark.MapOutputTracker.SHUFFLE_PUSH_MAP_ID
+import org.apache.spark.internal.Logging
+import org.apache.spark.network.buffer.ManagedBuffer
+import org.apache.spark.network.shuffle.{BlockStoreClient, MergedBlockMeta,
MergedBlocksMetaListener}
+import org.apache.spark.storage.BlockManagerId.SHUFFLE_MERGER_IDENTIFIER
+import org.apache.spark.storage.ShuffleBlockFetcherIterator._
+
+/**
+ * Helper class for [[ShuffleBlockFetcherIterator]] that encapsulates all the
push-based
+ * functionality to fetch merged block meta and merged shuffle block chunks.
+ */
+private class PushBasedFetchHelper(
+ private val iterator: ShuffleBlockFetcherIterator,
+ private val shuffleClient: BlockStoreClient,
+ private val blockManager: BlockManager,
+ private val mapOutputTracker: MapOutputTracker) extends Logging {
+
+ private[this] val startTimeNs = System.nanoTime()
+
+ private[this] val localShuffleMergerBlockMgrId = BlockManagerId(
+ SHUFFLE_MERGER_IDENTIFIER, blockManager.blockManagerId.host,
+ blockManager.blockManagerId.port, blockManager.blockManagerId.topologyInfo)
+
+ /** A map for storing merged block shuffle chunk bitmap */
+ private[this] val chunksMetaMap = new mutable.HashMap[ShuffleBlockChunkId,
RoaringBitmap]()
+
+ /**
+ * Returns true if the address is for a push-merged block.
+ */
+ def isMergedShuffleBlockAddress(address: BlockManagerId): Boolean = {
+ SHUFFLE_MERGER_IDENTIFIER.equals(address.executorId)
+ }
+
+ /**
+ * Returns true if the address is not of executor local or merged local
block. false otherwise.
+ */
+ def isNotExecutorOrMergedLocal(address: BlockManagerId): Boolean = {
+ (isMergedShuffleBlockAddress(address) && address.host !=
blockManager.blockManagerId.host) ||
+ (!isMergedShuffleBlockAddress(address) && address !=
blockManager.blockManagerId)
+ }
+
+ /**
+ * Returns true if the address is of a merged local block, false otherwise.
+ */
+ def isMergedLocal(address: BlockManagerId): Boolean = {
+ isMergedShuffleBlockAddress(address) && address.host ==
blockManager.blockManagerId.host
+ }
+
+ def getNumberOfBlocksInChunk(blockId : ShuffleBlockChunkId): Int = {
+ chunksMetaMap(blockId).getCardinality
+ }
+
+ def removeChunk(blockId: ShuffleBlockChunkId): Unit = {
+ chunksMetaMap.remove(blockId)
+ }
+
+ def createChunkBlockInfosFromMetaResponse(
+ shuffleId: Int,
+ reduceId: Int,
+ blockSize: Long,
+ numChunks: Int,
+ bitmaps: Array[RoaringBitmap]): ArrayBuffer[(BlockId, Long, Int)] = {
+ val approxChunkSize = blockSize / numChunks
+ val blocksToFetch = new ArrayBuffer[(BlockId, Long, Int)]()
+ for (i <- 0 until numChunks) {
+ val blockChunkId = ShuffleBlockChunkId(shuffleId, reduceId, i)
+ chunksMetaMap.put(blockChunkId, bitmaps(i))
+ logDebug(s"adding block chunk $blockChunkId of size $approxChunkSize")
+ blocksToFetch += ((blockChunkId, approxChunkSize, SHUFFLE_PUSH_MAP_ID))
+ }
+ blocksToFetch
+ }
+
+ def sendFetchMergedStatusRequest(req: FetchRequest): Unit = {
+ val sizeMap = req.blocks.map {
+ case FetchBlockInfo(blockId, size, _) =>
+ val shuffleBlockId = blockId.asInstanceOf[ShuffleBlockId]
+ ((shuffleBlockId.shuffleId, shuffleBlockId.reduceId), size)}.toMap
+ val address = req.address
+ val mergedBlocksMetaListener = new MergedBlocksMetaListener {
+ override def onSuccess(shuffleId: Int, reduceId: Int, meta:
MergedBlockMeta): Unit = {
+ logInfo(s"Received the meta of merged block for ($shuffleId,
$reduceId) " +
+ s"from ${req.address.host}:${req.address.port}")
+ try {
+ iterator.addToResultsQueue(MergedBlocksMetaFetchResult(shuffleId,
reduceId,
+ sizeMap((shuffleId, reduceId)), meta.getNumChunks,
meta.readChunkBitmaps(), address))
+ } catch {
+ case exception: Throwable =>
+ logError(s"Failed to parse the meta of merged block for
($shuffleId, $reduceId) " +
+ s"from ${req.address.host}:${req.address.port}", exception)
+ iterator.addToResultsQueue(
+ MergedBlocksMetaFailedFetchResult(shuffleId, reduceId, address))
+ }
+ }
+
+ override def onFailure(shuffleId: Int, reduceId: Int, exception:
Throwable): Unit = {
+ logError(s"Failed to get the meta of merged block for ($shuffleId,
$reduceId) " +
+ s"from ${req.address.host}:${req.address.port}", exception)
+
iterator.addToResultsQueue(MergedBlocksMetaFailedFetchResult(shuffleId,
reduceId, address))
+ }
+ }
+ req.blocks.foreach { block =>
+ val shuffleBlockId = block.blockId.asInstanceOf[ShuffleBlockId]
+ shuffleClient.getMergedBlockMeta(address.host, address.port,
shuffleBlockId.shuffleId,
+ shuffleBlockId.reduceId, mergedBlocksMetaListener)
+ }
+ }
+
+ // Fetch all outstanding merged local blocks
+ def fetchAllMergedLocalBlocks(
+ mergedLocalBlocks: mutable.LinkedHashSet[BlockId]): Unit = {
+ if (mergedLocalBlocks.nonEmpty) {
+ blockManager.hostLocalDirManager.foreach(fetchMergedLocalBlocks(_,
mergedLocalBlocks))
+ }
+ }
+
+ /**
+ * Fetch the merged blocks dirs if they are not in the cache and eventually
fetch merged local
+ * blocks.
+ */
+ private def fetchMergedLocalBlocks(
+ hostLocalDirManager: HostLocalDirManager,
+ mergedLocalBlocks: mutable.LinkedHashSet[BlockId]): Unit = {
+ val cachedMergerDirs = hostLocalDirManager.getCachedHostLocalDirs.get(
+ SHUFFLE_MERGER_IDENTIFIER)
+ if (cachedMergerDirs.isDefined) {
+ logDebug(s"Fetching local merged blocks with cached executors dir: " +
+ s"${cachedMergerDirs.get.mkString(", ")}")
+ mergedLocalBlocks.foreach(blockId =>
+ fetchMergedLocalBlock(blockId, cachedMergerDirs.get,
localShuffleMergerBlockMgrId))
+ } else {
+ logDebug(s"Asynchronous fetching local merged blocks without cached
executors dir")
+ hostLocalDirManager.getHostLocalDirs(localShuffleMergerBlockMgrId.host,
+ localShuffleMergerBlockMgrId.port, Array(SHUFFLE_MERGER_IDENTIFIER)) {
+ case Success(dirs) =>
+ mergedLocalBlocks.takeWhile {
+ blockId =>
+ logDebug(s"Successfully fetched local dirs: " +
+ s"${dirs.get(SHUFFLE_MERGER_IDENTIFIER).mkString(", ")}")
+ fetchMergedLocalBlock(blockId, dirs(SHUFFLE_MERGER_IDENTIFIER),
+ localShuffleMergerBlockMgrId)
+ }
+ logDebug(s"Got local merged blocks (without cached executors' dir)
in " +
+ s"${TimeUnit.NANOSECONDS.toMillis(System.nanoTime() -
startTimeNs)} ms")
+ case Failure(throwable) =>
+ // If we see an exception with getting the local dirs for local
merged blocks,
+ // we fallback to fetch the original unmerged blocks. We do not
report block fetch
+ // failure.
+ logWarning(s"Error occurred while getting the local dirs for local
merged " +
+ s"blocks: ${mergedLocalBlocks.mkString(", ")}. Fetch the original
blocks instead",
+ throwable)
+ mergedLocalBlocks.foreach(
+ blockId => iterator.addToResultsQueue(
+ IgnoreFetchResult(blockId, localShuffleMergerBlockMgrId, 0,
isNetworkReqDone = false))
+ )
+ }
+ }
+ }
+
+ /**
+ * Fetch a single local merged block.
+ * @param blockId ShuffleBlockId to be fetched
+ * @param localDirs Local directories where the merged shuffle files are
stored
+ * @param blockManagerId BlockManagerId
+ * @return Boolean represents successful or failed fetch
+ */
+ private[this] def fetchMergedLocalBlock(
+ blockId: BlockId,
+ localDirs: Array[String],
+ blockManagerId: BlockManagerId): Boolean = {
+ try {
+ val shuffleBlockId = blockId.asInstanceOf[ShuffleBlockId]
+ val chunksMeta = blockManager.getMergedBlockMeta(shuffleBlockId,
localDirs)
+ .readChunkBitmaps()
+ // Fetch local merged shuffle block data as multiple chunks
+ val bufs: Seq[ManagedBuffer] =
blockManager.getMergedBlockData(shuffleBlockId, localDirs)
+ // Update total number of blocks to fetch, reflecting the multiple local
chunks
+ iterator.foundMoreBlocksToFetch(bufs.size - 1)
+ for (chunkId <- bufs.indices) {
+ val buf = bufs(chunkId)
+ buf.retain()
+ val shuffleChunkId = ShuffleBlockChunkId(shuffleBlockId.shuffleId,
+ shuffleBlockId.reduceId, chunkId)
+ iterator.addToResultsQueue(
+ SuccessFetchResult(shuffleChunkId, SHUFFLE_PUSH_MAP_ID,
blockManagerId, buf.size(), buf,
+ isNetworkReqDone = false))
+ chunksMetaMap.put(shuffleChunkId, chunksMeta(chunkId))
+ }
+ true
+ } catch {
+ case e: Exception =>
+ // If we see an exception with reading a local merged block, we
fallback to
+ // fetch the original unmerged blocks. We do not report block fetch
failure
+ // and will continue with the remaining local block read.
+ logWarning(s"Error occurred while fetching local merged block, " +
+ s"prepare to fetch the original blocks", e)
+ iterator.addToResultsQueue(
+ IgnoreFetchResult(blockId, blockManagerId, 0, isNetworkReqDone =
false))
+ false
+ }
+ }
+
+ /**
+ * Initiate fetching fallback blocks for a merged block (or a merged block
chunk) that's failed
+ * to fetch.
+ * It calls out to map output tracker to get the list of original blocks for
the
+ * given merged blocks, split them into remote and local blocks, and process
them
+ * accordingly.
+ * The fallback happens when:
+ * 1. There is an exception while creating shuffle block chunk from local
merged shuffle block.
+ * See fetchLocalBlock.
+ * 2. There is a failure when fetching remote shuffle block chunks.
+ * 3. There is a failure when processing SuccessFetchResult which is for a
shuffle chunk
+ * (local or remote).
+ *
+ * @return number of blocks processed
+ */
+ def initiateFallbackBlockFetchForMergedBlock(
+ blockId: BlockId,
+ address: BlockManagerId): Int = {
+ logWarning(s"Falling back to fetch the original unmerged blocks for merged
block $blockId")
+ // Increase the blocks processed since we will process another block in
the next iteration of
+ // the while loop in ShuffleBlockFetcherIterator.next().
+ var blocksProcessed = 1
+ val fallbackBlocksByAddr: Iterator[(BlockManagerId, Seq[(BlockId, Long,
Int)])] =
+ if (blockId.isShuffle) {
+ val shuffleBlockId = blockId.asInstanceOf[ShuffleBlockId]
+ mapOutputTracker.getMapSizesForMergeResult(
+ shuffleBlockId.shuffleId, shuffleBlockId.reduceId)
+ } else {
+ val shuffleChunkId = blockId.asInstanceOf[ShuffleBlockChunkId]
+ val chunkBitmap: RoaringBitmap =
chunksMetaMap.remove(shuffleChunkId).orNull
+ // When there is a failure to fetch a remote merged shuffle block
chunk, then we try to
+ // fallback not only for that particular remote shuffle block chunk
but also for all the
+ // pending block chunks that belong to the same host. The reason for
doing so is that it is
+ // very likely that the subsequent requests for merged block chunks
from this host will fail
+ // as well. Since, push-based shuffle is best effort and we try not to
increase the delay
+ // of the fetches, we immediately fallback for all the pending shuffle
chunks in the
+ // fetchRequests queue.
+ if (isNotExecutorOrMergedLocal(address)) {
+ // Fallback for all the pending fetch requests
+ val pendingShuffleChunks =
iterator.removePendingChunks(shuffleChunkId, address)
+ if (pendingShuffleChunks.nonEmpty) {
+ pendingShuffleChunks.foreach { pendingBlockId =>
+ logWarning(s"Falling back immediately for merged block
$pendingBlockId")
+ val bitmapOfPendingChunk: RoaringBitmap =
+ chunksMetaMap.remove(pendingBlockId).orNull
+ assert(bitmapOfPendingChunk != null)
+ chunkBitmap.or(bitmapOfPendingChunk)
Review comment:
Can we have NPE here ?
##########
File path:
core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -347,77 +355,118 @@ final class ShuffleBlockFetcherIterator(
}
}
- private[this] def partitionBlocksByFetchMode(): ArrayBuffer[FetchRequest] = {
+ /**
+ * This is called from initialize and also from the fallback which is
triggered from
+ * [[PushBasedFetchHelper]].
+ */
+ private[this] def partitionBlocksByFetchMode(
+ blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])],
+ localBlocks: mutable.LinkedHashSet[(BlockId, Int)],
+ hostLocalBlocksByExecutor: mutable.LinkedHashMap[BlockManagerId,
Seq[(BlockId, Long, Int)]],
+ mergedLocalBlocks: mutable.LinkedHashSet[BlockId]):
ArrayBuffer[FetchRequest] = {
logDebug(s"maxBytesInFlight: $maxBytesInFlight, targetRemoteRequestSize: "
+ s"$targetRemoteRequestSize, maxBlocksInFlightPerAddress:
$maxBlocksInFlightPerAddress")
- // Partition to local, host-local and remote blocks. Remote blocks are
further split into
- // FetchRequests of size at most maxBytesInFlight in order to limit the
amount of data in flight
+ // Partition to local, host-local, merged-local, remote (includes
merged-remote) blocks.
+ // Remote blocks are further split into FetchRequests of size at most
maxBytesInFlight in order
+ // to limit the amount of data in flight
val collectedRemoteRequests = new ArrayBuffer[FetchRequest]
+ val hostLocalBlocksCurrentIteration = mutable.LinkedHashSet[(BlockId,
Int)]()
var localBlockBytes = 0L
var hostLocalBlockBytes = 0L
+ var mergedLocalBlockBytes = 0L
var remoteBlockBytes = 0L
+ val prevNumBlocksToFetch = numBlocksToFetch
val fallback = FallbackStorage.FALLBACK_BLOCK_MANAGER_ID.executorId
for ((address, blockInfos) <- blocksByAddress) {
- if (Seq(blockManager.blockManagerId.executorId,
fallback).contains(address.executorId)) {
- checkBlockSizes(blockInfos)
+ checkBlockSizes(blockInfos)
+ if (pushBasedFetchHelper.isMergedShuffleBlockAddress(address)) {
+ // These are push-based merged blocks or chunks of these merged blocks.
+ if (address.host == blockManager.blockManagerId.host) {
+ val pushMergedBlockInfos = blockInfos.map(
+ info => FetchBlockInfo(info._1, info._2, info._3))
+ numBlocksToFetch += pushMergedBlockInfos.size
+ mergedLocalBlocks ++= pushMergedBlockInfos.map(info => info.blockId)
+ val size = pushMergedBlockInfos.map(_.size).sum
+ logInfo(s"Got ${pushMergedBlockInfos.size} local merged blocks " +
+ s"of size $size")
+ mergedLocalBlockBytes += size
+ } else {
+ remoteBlockBytes += blockInfos.map(_._2).sum
+ collectFetchRequests(address, blockInfos, collectedRemoteRequests)
+ }
+ } else if (
+ Seq(blockManager.blockManagerId.executorId,
fallback).contains(address.executorId)) {
val mergedBlockInfos = mergeContinuousShuffleBlockIdsIfNeeded(
blockInfos.map(info => FetchBlockInfo(info._1, info._2, info._3)),
doBatchFetch)
numBlocksToFetch += mergedBlockInfos.size
localBlocks ++= mergedBlockInfos.map(info => (info.blockId,
info.mapIndex))
localBlockBytes += mergedBlockInfos.map(_.size).sum
} else if (blockManager.hostLocalDirManager.isDefined &&
address.host == blockManager.blockManagerId.host) {
- checkBlockSizes(blockInfos)
val mergedBlockInfos = mergeContinuousShuffleBlockIdsIfNeeded(
blockInfos.map(info => FetchBlockInfo(info._1, info._2, info._3)),
doBatchFetch)
numBlocksToFetch += mergedBlockInfos.size
val blocksForAddress =
mergedBlockInfos.map(info => (info.blockId, info.size,
info.mapIndex))
hostLocalBlocksByExecutor += address -> blocksForAddress
- hostLocalBlocks ++= blocksForAddress.map(info => (info._1, info._3))
+ hostLocalBlocksCurrentIteration ++= blocksForAddress.map(info =>
(info._1, info._3))
hostLocalBlockBytes += mergedBlockInfos.map(_.size).sum
} else {
remoteBlockBytes += blockInfos.map(_._2).sum
collectFetchRequests(address, blockInfos, collectedRemoteRequests)
}
}
val numRemoteBlocks = collectedRemoteRequests.map(_.blocks.size).sum
- val totalBytes = localBlockBytes + remoteBlockBytes + hostLocalBlockBytes
- assert(numBlocksToFetch == localBlocks.size + hostLocalBlocks.size +
numRemoteBlocks,
- s"The number of non-empty blocks $numBlocksToFetch doesn't equal to the
number of local " +
- s"blocks ${localBlocks.size} + the number of host-local blocks
${hostLocalBlocks.size} " +
- s"+ the number of remote blocks ${numRemoteBlocks}.")
- logInfo(s"Getting $numBlocksToFetch (${Utils.bytesToString(totalBytes)})
non-empty blocks " +
- s"including ${localBlocks.size}
(${Utils.bytesToString(localBlockBytes)}) local and " +
- s"${hostLocalBlocks.size} (${Utils.bytesToString(hostLocalBlockBytes)})
" +
- s"host-local and $numRemoteBlocks
(${Utils.bytesToString(remoteBlockBytes)}) remote blocks")
+ val totalBytes = localBlockBytes + remoteBlockBytes + hostLocalBlockBytes +
+ mergedLocalBlockBytes
+ val blocksToFetchCurrentIteration = numBlocksToFetch - prevNumBlocksToFetch
+ assert(blocksToFetchCurrentIteration == localBlocks.size +
Review comment:
Note: Here we are assuming `localBlocks` is empty when the method was
invoked.
##########
File path:
core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala
##########
@@ -0,0 +1,289 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import java.util.concurrent.TimeUnit
+
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+import scala.util.{Failure, Success}
+
+import org.roaringbitmap.RoaringBitmap
+
+import org.apache.spark.MapOutputTracker
+import org.apache.spark.MapOutputTracker.SHUFFLE_PUSH_MAP_ID
+import org.apache.spark.internal.Logging
+import org.apache.spark.network.buffer.ManagedBuffer
+import org.apache.spark.network.shuffle.{BlockStoreClient, MergedBlockMeta,
MergedBlocksMetaListener}
+import org.apache.spark.storage.BlockManagerId.SHUFFLE_MERGER_IDENTIFIER
+import org.apache.spark.storage.ShuffleBlockFetcherIterator._
+
+/**
+ * Helper class for [[ShuffleBlockFetcherIterator]] that encapsulates all the
push-based
+ * functionality to fetch merged block meta and merged shuffle block chunks.
+ */
+private class PushBasedFetchHelper(
+ private val iterator: ShuffleBlockFetcherIterator,
+ private val shuffleClient: BlockStoreClient,
+ private val blockManager: BlockManager,
+ private val mapOutputTracker: MapOutputTracker) extends Logging {
+
+ private[this] val startTimeNs = System.nanoTime()
+
+ private[this] val localShuffleMergerBlockMgrId = BlockManagerId(
+ SHUFFLE_MERGER_IDENTIFIER, blockManager.blockManagerId.host,
+ blockManager.blockManagerId.port, blockManager.blockManagerId.topologyInfo)
+
+ /** A map for storing merged block shuffle chunk bitmap */
+ private[this] val chunksMetaMap = new mutable.HashMap[ShuffleBlockChunkId,
RoaringBitmap]()
+
+ /**
+ * Returns true if the address is for a push-merged block.
+ */
+ def isMergedShuffleBlockAddress(address: BlockManagerId): Boolean = {
+ SHUFFLE_MERGER_IDENTIFIER.equals(address.executorId)
+ }
+
+ /**
+ * Returns true if the address is not of executor local or merged local
block. false otherwise.
+ */
+ def isNotExecutorOrMergedLocal(address: BlockManagerId): Boolean = {
+ (isMergedShuffleBlockAddress(address) && address.host !=
blockManager.blockManagerId.host) ||
+ (!isMergedShuffleBlockAddress(address) && address !=
blockManager.blockManagerId)
+ }
+
+ /**
+ * Returns true if the address is of a merged local block, false otherwise.
+ */
+ def isMergedLocal(address: BlockManagerId): Boolean = {
+ isMergedShuffleBlockAddress(address) && address.host ==
blockManager.blockManagerId.host
+ }
+
+ def getNumberOfBlocksInChunk(blockId : ShuffleBlockChunkId): Int = {
+ chunksMetaMap(blockId).getCardinality
+ }
+
+ def removeChunk(blockId: ShuffleBlockChunkId): Unit = {
+ chunksMetaMap.remove(blockId)
+ }
+
+ def createChunkBlockInfosFromMetaResponse(
+ shuffleId: Int,
+ reduceId: Int,
+ blockSize: Long,
+ numChunks: Int,
+ bitmaps: Array[RoaringBitmap]): ArrayBuffer[(BlockId, Long, Int)] = {
+ val approxChunkSize = blockSize / numChunks
+ val blocksToFetch = new ArrayBuffer[(BlockId, Long, Int)]()
+ for (i <- 0 until numChunks) {
+ val blockChunkId = ShuffleBlockChunkId(shuffleId, reduceId, i)
+ chunksMetaMap.put(blockChunkId, bitmaps(i))
+ logDebug(s"adding block chunk $blockChunkId of size $approxChunkSize")
+ blocksToFetch += ((blockChunkId, approxChunkSize, SHUFFLE_PUSH_MAP_ID))
+ }
+ blocksToFetch
+ }
+
+ def sendFetchMergedStatusRequest(req: FetchRequest): Unit = {
+ val sizeMap = req.blocks.map {
+ case FetchBlockInfo(blockId, size, _) =>
+ val shuffleBlockId = blockId.asInstanceOf[ShuffleBlockId]
+ ((shuffleBlockId.shuffleId, shuffleBlockId.reduceId), size)}.toMap
+ val address = req.address
+ val mergedBlocksMetaListener = new MergedBlocksMetaListener {
+ override def onSuccess(shuffleId: Int, reduceId: Int, meta:
MergedBlockMeta): Unit = {
+ logInfo(s"Received the meta of merged block for ($shuffleId,
$reduceId) " +
+ s"from ${req.address.host}:${req.address.port}")
+ try {
+ iterator.addToResultsQueue(MergedBlocksMetaFetchResult(shuffleId,
reduceId,
+ sizeMap((shuffleId, reduceId)), meta.getNumChunks,
meta.readChunkBitmaps(), address))
+ } catch {
+ case exception: Throwable =>
+ logError(s"Failed to parse the meta of merged block for
($shuffleId, $reduceId) " +
+ s"from ${req.address.host}:${req.address.port}", exception)
+ iterator.addToResultsQueue(
+ MergedBlocksMetaFailedFetchResult(shuffleId, reduceId, address))
+ }
+ }
+
+ override def onFailure(shuffleId: Int, reduceId: Int, exception:
Throwable): Unit = {
+ logError(s"Failed to get the meta of merged block for ($shuffleId,
$reduceId) " +
+ s"from ${req.address.host}:${req.address.port}", exception)
+
iterator.addToResultsQueue(MergedBlocksMetaFailedFetchResult(shuffleId,
reduceId, address))
+ }
+ }
+ req.blocks.foreach { block =>
+ val shuffleBlockId = block.blockId.asInstanceOf[ShuffleBlockId]
+ shuffleClient.getMergedBlockMeta(address.host, address.port,
shuffleBlockId.shuffleId,
+ shuffleBlockId.reduceId, mergedBlocksMetaListener)
+ }
+ }
+
+ // Fetch all outstanding merged local blocks
+ def fetchAllMergedLocalBlocks(
+ mergedLocalBlocks: mutable.LinkedHashSet[BlockId]): Unit = {
+ if (mergedLocalBlocks.nonEmpty) {
+ blockManager.hostLocalDirManager.foreach(fetchMergedLocalBlocks(_,
mergedLocalBlocks))
+ }
+ }
+
+ /**
+ * Fetch the merged blocks dirs if they are not in the cache and eventually
fetch merged local
+ * blocks.
+ */
+ private def fetchMergedLocalBlocks(
+ hostLocalDirManager: HostLocalDirManager,
+ mergedLocalBlocks: mutable.LinkedHashSet[BlockId]): Unit = {
+ val cachedMergerDirs = hostLocalDirManager.getCachedHostLocalDirs.get(
+ SHUFFLE_MERGER_IDENTIFIER)
+ if (cachedMergerDirs.isDefined) {
+ logDebug(s"Fetching local merged blocks with cached executors dir: " +
+ s"${cachedMergerDirs.get.mkString(", ")}")
+ mergedLocalBlocks.foreach(blockId =>
+ fetchMergedLocalBlock(blockId, cachedMergerDirs.get,
localShuffleMergerBlockMgrId))
+ } else {
+ logDebug(s"Asynchronous fetching local merged blocks without cached
executors dir")
+ hostLocalDirManager.getHostLocalDirs(localShuffleMergerBlockMgrId.host,
+ localShuffleMergerBlockMgrId.port, Array(SHUFFLE_MERGER_IDENTIFIER)) {
+ case Success(dirs) =>
+ mergedLocalBlocks.takeWhile {
+ blockId =>
+ logDebug(s"Successfully fetched local dirs: " +
+ s"${dirs.get(SHUFFLE_MERGER_IDENTIFIER).mkString(", ")}")
+ fetchMergedLocalBlock(blockId, dirs(SHUFFLE_MERGER_IDENTIFIER),
+ localShuffleMergerBlockMgrId)
+ }
+ logDebug(s"Got local merged blocks (without cached executors' dir)
in " +
+ s"${TimeUnit.NANOSECONDS.toMillis(System.nanoTime() -
startTimeNs)} ms")
+ case Failure(throwable) =>
+ // If we see an exception with getting the local dirs for local
merged blocks,
+ // we fallback to fetch the original unmerged blocks. We do not
report block fetch
+ // failure.
+ logWarning(s"Error occurred while getting the local dirs for local
merged " +
+ s"blocks: ${mergedLocalBlocks.mkString(", ")}. Fetch the original
blocks instead",
+ throwable)
+ mergedLocalBlocks.foreach(
+ blockId => iterator.addToResultsQueue(
+ IgnoreFetchResult(blockId, localShuffleMergerBlockMgrId, 0,
isNetworkReqDone = false))
+ )
+ }
+ }
+ }
+
+ /**
+ * Fetch a single local merged block.
+ * @param blockId ShuffleBlockId to be fetched
+ * @param localDirs Local directories where the merged shuffle files are
stored
+ * @param blockManagerId BlockManagerId
+ * @return Boolean represents successful or failed fetch
+ */
+ private[this] def fetchMergedLocalBlock(
+ blockId: BlockId,
+ localDirs: Array[String],
+ blockManagerId: BlockManagerId): Boolean = {
+ try {
+ val shuffleBlockId = blockId.asInstanceOf[ShuffleBlockId]
+ val chunksMeta = blockManager.getMergedBlockMeta(shuffleBlockId,
localDirs)
+ .readChunkBitmaps()
+ // Fetch local merged shuffle block data as multiple chunks
+ val bufs: Seq[ManagedBuffer] =
blockManager.getMergedBlockData(shuffleBlockId, localDirs)
+ // Update total number of blocks to fetch, reflecting the multiple local
chunks
+ iterator.foundMoreBlocksToFetch(bufs.size - 1)
+ for (chunkId <- bufs.indices) {
+ val buf = bufs(chunkId)
+ buf.retain()
+ val shuffleChunkId = ShuffleBlockChunkId(shuffleBlockId.shuffleId,
+ shuffleBlockId.reduceId, chunkId)
+ iterator.addToResultsQueue(
+ SuccessFetchResult(shuffleChunkId, SHUFFLE_PUSH_MAP_ID,
blockManagerId, buf.size(), buf,
+ isNetworkReqDone = false))
+ chunksMetaMap.put(shuffleChunkId, chunksMeta(chunkId))
+ }
+ true
+ } catch {
+ case e: Exception =>
+ // If we see an exception with reading a local merged block, we
fallback to
+ // fetch the original unmerged blocks. We do not report block fetch
failure
+ // and will continue with the remaining local block read.
+ logWarning(s"Error occurred while fetching local merged block, " +
+ s"prepare to fetch the original blocks", e)
+ iterator.addToResultsQueue(
+ IgnoreFetchResult(blockId, blockManagerId, 0, isNetworkReqDone =
false))
+ false
+ }
+ }
+
+ /**
+ * Initiate fetching fallback blocks for a merged block (or a merged block
chunk) that's failed
+ * to fetch.
+ * It calls out to map output tracker to get the list of original blocks for
the
+ * given merged blocks, split them into remote and local blocks, and process
them
+ * accordingly.
+ * The fallback happens when:
+ * 1. There is an exception while creating shuffle block chunk from local
merged shuffle block.
+ * See fetchLocalBlock.
+ * 2. There is a failure when fetching remote shuffle block chunks.
+ * 3. There is a failure when processing SuccessFetchResult which is for a
shuffle chunk
+ * (local or remote).
+ *
+ * @return number of blocks processed
+ */
+ def initiateFallbackBlockFetchForMergedBlock(
+ blockId: BlockId,
+ address: BlockManagerId): Int = {
+ logWarning(s"Falling back to fetch the original unmerged blocks for merged
block $blockId")
+ // Increase the blocks processed since we will process another block in
the next iteration of
+ // the while loop in ShuffleBlockFetcherIterator.next().
+ var blocksProcessed = 1
+ val fallbackBlocksByAddr: Iterator[(BlockManagerId, Seq[(BlockId, Long,
Int)])] =
+ if (blockId.isShuffle) {
+ val shuffleBlockId = blockId.asInstanceOf[ShuffleBlockId]
+ mapOutputTracker.getMapSizesForMergeResult(
+ shuffleBlockId.shuffleId, shuffleBlockId.reduceId)
+ } else {
+ val shuffleChunkId = blockId.asInstanceOf[ShuffleBlockChunkId]
+ val chunkBitmap: RoaringBitmap =
chunksMetaMap.remove(shuffleChunkId).orNull
+ // When there is a failure to fetch a remote merged shuffle block
chunk, then we try to
+ // fallback not only for that particular remote shuffle block chunk
but also for all the
+ // pending block chunks that belong to the same host. The reason for
doing so is that it is
+ // very likely that the subsequent requests for merged block chunks
from this host will fail
+ // as well. Since, push-based shuffle is best effort and we try not to
increase the delay
+ // of the fetches, we immediately fallback for all the pending shuffle
chunks in the
+ // fetchRequests queue.
+ if (isNotExecutorOrMergedLocal(address)) {
+ // Fallback for all the pending fetch requests
+ val pendingShuffleChunks =
iterator.removePendingChunks(shuffleChunkId, address)
+ if (pendingShuffleChunks.nonEmpty) {
+ pendingShuffleChunks.foreach { pendingBlockId =>
+ logWarning(s"Falling back immediately for merged block
$pendingBlockId")
+ val bitmapOfPendingChunk: RoaringBitmap =
+ chunksMetaMap.remove(pendingBlockId).orNull
+ assert(bitmapOfPendingChunk != null)
Review comment:
Any possibility of race here ?
##########
File path:
core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -347,77 +355,118 @@ final class ShuffleBlockFetcherIterator(
}
}
- private[this] def partitionBlocksByFetchMode(): ArrayBuffer[FetchRequest] = {
+ /**
+ * This is called from initialize and also from the fallback which is
triggered from
+ * [[PushBasedFetchHelper]].
+ */
+ private[this] def partitionBlocksByFetchMode(
+ blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])],
+ localBlocks: mutable.LinkedHashSet[(BlockId, Int)],
+ hostLocalBlocksByExecutor: mutable.LinkedHashMap[BlockManagerId,
Seq[(BlockId, Long, Int)]],
+ mergedLocalBlocks: mutable.LinkedHashSet[BlockId]):
ArrayBuffer[FetchRequest] = {
logDebug(s"maxBytesInFlight: $maxBytesInFlight, targetRemoteRequestSize: "
+ s"$targetRemoteRequestSize, maxBlocksInFlightPerAddress:
$maxBlocksInFlightPerAddress")
- // Partition to local, host-local and remote blocks. Remote blocks are
further split into
- // FetchRequests of size at most maxBytesInFlight in order to limit the
amount of data in flight
+ // Partition to local, host-local, merged-local, remote (includes
merged-remote) blocks.
+ // Remote blocks are further split into FetchRequests of size at most
maxBytesInFlight in order
+ // to limit the amount of data in flight
val collectedRemoteRequests = new ArrayBuffer[FetchRequest]
+ val hostLocalBlocksCurrentIteration = mutable.LinkedHashSet[(BlockId,
Int)]()
var localBlockBytes = 0L
var hostLocalBlockBytes = 0L
+ var mergedLocalBlockBytes = 0L
var remoteBlockBytes = 0L
+ val prevNumBlocksToFetch = numBlocksToFetch
val fallback = FallbackStorage.FALLBACK_BLOCK_MANAGER_ID.executorId
for ((address, blockInfos) <- blocksByAddress) {
- if (Seq(blockManager.blockManagerId.executorId,
fallback).contains(address.executorId)) {
- checkBlockSizes(blockInfos)
+ checkBlockSizes(blockInfos)
+ if (pushBasedFetchHelper.isMergedShuffleBlockAddress(address)) {
+ // These are push-based merged blocks or chunks of these merged blocks.
+ if (address.host == blockManager.blockManagerId.host) {
+ val pushMergedBlockInfos = blockInfos.map(
+ info => FetchBlockInfo(info._1, info._2, info._3))
+ numBlocksToFetch += pushMergedBlockInfos.size
+ mergedLocalBlocks ++= pushMergedBlockInfos.map(info => info.blockId)
+ val size = pushMergedBlockInfos.map(_.size).sum
+ logInfo(s"Got ${pushMergedBlockInfos.size} local merged blocks " +
+ s"of size $size")
+ mergedLocalBlockBytes += size
+ } else {
+ remoteBlockBytes += blockInfos.map(_._2).sum
+ collectFetchRequests(address, blockInfos, collectedRemoteRequests)
+ }
+ } else if (
+ Seq(blockManager.blockManagerId.executorId,
fallback).contains(address.executorId)) {
val mergedBlockInfos = mergeContinuousShuffleBlockIdsIfNeeded(
blockInfos.map(info => FetchBlockInfo(info._1, info._2, info._3)),
doBatchFetch)
numBlocksToFetch += mergedBlockInfos.size
localBlocks ++= mergedBlockInfos.map(info => (info.blockId,
info.mapIndex))
localBlockBytes += mergedBlockInfos.map(_.size).sum
} else if (blockManager.hostLocalDirManager.isDefined &&
address.host == blockManager.blockManagerId.host) {
- checkBlockSizes(blockInfos)
val mergedBlockInfos = mergeContinuousShuffleBlockIdsIfNeeded(
blockInfos.map(info => FetchBlockInfo(info._1, info._2, info._3)),
doBatchFetch)
numBlocksToFetch += mergedBlockInfos.size
val blocksForAddress =
mergedBlockInfos.map(info => (info.blockId, info.size,
info.mapIndex))
hostLocalBlocksByExecutor += address -> blocksForAddress
- hostLocalBlocks ++= blocksForAddress.map(info => (info._1, info._3))
+ hostLocalBlocksCurrentIteration ++= blocksForAddress.map(info =>
(info._1, info._3))
hostLocalBlockBytes += mergedBlockInfos.map(_.size).sum
} else {
remoteBlockBytes += blockInfos.map(_._2).sum
collectFetchRequests(address, blockInfos, collectedRemoteRequests)
}
}
val numRemoteBlocks = collectedRemoteRequests.map(_.blocks.size).sum
- val totalBytes = localBlockBytes + remoteBlockBytes + hostLocalBlockBytes
- assert(numBlocksToFetch == localBlocks.size + hostLocalBlocks.size +
numRemoteBlocks,
- s"The number of non-empty blocks $numBlocksToFetch doesn't equal to the
number of local " +
- s"blocks ${localBlocks.size} + the number of host-local blocks
${hostLocalBlocks.size} " +
- s"+ the number of remote blocks ${numRemoteBlocks}.")
- logInfo(s"Getting $numBlocksToFetch (${Utils.bytesToString(totalBytes)})
non-empty blocks " +
- s"including ${localBlocks.size}
(${Utils.bytesToString(localBlockBytes)}) local and " +
- s"${hostLocalBlocks.size} (${Utils.bytesToString(hostLocalBlockBytes)})
" +
- s"host-local and $numRemoteBlocks
(${Utils.bytesToString(remoteBlockBytes)}) remote blocks")
+ val totalBytes = localBlockBytes + remoteBlockBytes + hostLocalBlockBytes +
+ mergedLocalBlockBytes
+ val blocksToFetchCurrentIteration = numBlocksToFetch - prevNumBlocksToFetch
+ assert(blocksToFetchCurrentIteration == localBlocks.size +
+ hostLocalBlocksCurrentIteration.size + numRemoteBlocks +
mergedLocalBlocks.size,
+ s"The number of non-empty blocks $blocksToFetchCurrentIteration doesn't
equal to " +
+ s"the number of local blocks ${localBlocks.size} + " +
+ s"the number of host-local blocks
${hostLocalBlocksCurrentIteration.size} " +
+ s"the number of merged-local blocks ${mergedLocalBlocks.size} " +
+ s"+ the number of remote blocks ${numRemoteBlocks} ")
+ logInfo(s"Getting $blocksToFetchCurrentIteration " +
+ s"(${Utils.bytesToString(totalBytes)}) non-empty blocks including " +
+ s"${localBlocks.size} (${Utils.bytesToString(localBlockBytes)}) local
and " +
+ s"${hostLocalBlocksCurrentIteration.size}
(${Utils.bytesToString(hostLocalBlockBytes)}) " +
+ s"host-local and ${mergedLocalBlocks.size}
(${Utils.bytesToString(mergedLocalBlockBytes)}) " +
+ s"local merged and $numRemoteBlocks
(${Utils.bytesToString(remoteBlockBytes)}) " +
+ s"remote blocks")
+ if (hostLocalBlocksCurrentIteration.nonEmpty) {
+ this.hostLocalBlocks ++= hostLocalBlocksCurrentIteration
+ }
collectedRemoteRequests
}
private def createFetchRequest(
blocks: Seq[FetchBlockInfo],
- address: BlockManagerId): FetchRequest = {
+ address: BlockManagerId,
+ forMergedMetas: Boolean = false): FetchRequest = {
logDebug(s"Creating fetch request of ${blocks.map(_.size).sum} at $address
"
+ s"with ${blocks.size} blocks")
- FetchRequest(address, blocks)
+ FetchRequest(address, blocks, forMergedMetas)
}
private def createFetchRequests(
curBlocks: Seq[FetchBlockInfo],
address: BlockManagerId,
isLast: Boolean,
- collectedRemoteRequests: ArrayBuffer[FetchRequest]): Seq[FetchBlockInfo]
= {
- val mergedBlocks = mergeContinuousShuffleBlockIdsIfNeeded(curBlocks,
doBatchFetch)
+ collectedRemoteRequests: ArrayBuffer[FetchRequest],
+ enableBatchFetch: Boolean,
+ forMergedMetas: Boolean = false): Seq[FetchBlockInfo] = {
+ val mergedBlocks = mergeContinuousShuffleBlockIdsIfNeeded(curBlocks,
enableBatchFetch)
Review comment:
Is `mergeContinuousShuffleBlockIdsIfNeeded` relevant for merged
blocks/chunks ?
If not, any side effects of doing this ?
##########
File path:
core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -871,6 +1054,82 @@ final class ShuffleBlockFetcherIterator(
"Failed to get block " + blockId + ", which is not a shuffle block",
e)
}
}
+
+ /**
+ * All the below methods are used by [[PushBasedFetchHelper]] to communicate
with the iterator
+ */
+ private[storage] def addToResultsQueue(result: FetchResult): Unit = {
+ results.put(result)
+ }
+
+ private[storage] def foundMoreBlocksToFetch(moreBlocksToFetch: Int): Unit = {
+ numBlocksToFetch += moreBlocksToFetch
Review comment:
`foundMoreBlocksToFetch` -> `incrementNumBlocksToFetch` ?
##########
File path:
core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -436,24 +485,48 @@ final class ShuffleBlockFetcherIterator(
val iterator = blockInfos.iterator
var curRequestSize = 0L
var curBlocks = Seq.empty[FetchBlockInfo]
-
while (iterator.hasNext) {
val (blockId, size, mapIndex) = iterator.next()
- assertPositiveBlockSize(blockId, size)
curBlocks = curBlocks ++ Seq(FetchBlockInfo(blockId, size, mapIndex))
curRequestSize += size
- // For batch fetch, the actual block in flight should count for merged
block.
- val mayExceedsMaxBlocks = !doBatchFetch && curBlocks.size >=
maxBlocksInFlightPerAddress
- if (curRequestSize >= targetRemoteRequestSize || mayExceedsMaxBlocks) {
- curBlocks = createFetchRequests(curBlocks, address, isLast = false,
- collectedRemoteRequests)
- curRequestSize = curBlocks.map(_.size).sum
+ blockId match {
+ // Either all blocks are merged blocks, merged block chunks, or
original non-merged blocks.
+ // Based on these types, we decide to do batch fetch and create
FetchRequests with
+ // forMergedMetas set.
+ case ShuffleBlockChunkId(_, _, _) =>
+ if (curRequestSize >= targetRemoteRequestSize ||
+ curBlocks.size >= maxBlocksInFlightPerAddress) {
+ curBlocks = createFetchRequests(curBlocks, address, isLast = false,
+ collectedRemoteRequests, enableBatchFetch = false)
+ curRequestSize = curBlocks.map(_.size).sum
+ }
+ case ShuffleBlockId(_, SHUFFLE_PUSH_MAP_ID, _) =>
+ if (curBlocks.size >= maxBlocksInFlightPerAddress) {
+ curBlocks = createFetchRequests(curBlocks, address, isLast = false,
+ collectedRemoteRequests, enableBatchFetch = false,
forMergedMetas = true)
+ }
+ case _ =>
+ // For batch fetch, the actual block in flight should count for
merged block.
+ val mayExceedsMaxBlocks = !doBatchFetch && curBlocks.size >=
maxBlocksInFlightPerAddress
+ if (curRequestSize >= targetRemoteRequestSize ||
mayExceedsMaxBlocks) {
+ curBlocks = createFetchRequests(curBlocks, address, isLast = false,
+ collectedRemoteRequests, enableBatchFetch = doBatchFetch)
+ curRequestSize = curBlocks.map(_.size).sum
+ }
}
}
// Add in the final request
if (curBlocks.nonEmpty) {
+ val (enableBatchFetch, areMergedBlocks) = {
+ curBlocks.head.blockId match {
+ case ShuffleBlockChunkId(_, _, _) => (false, false)
+ case ShuffleBlockId(_, SHUFFLE_PUSH_MAP_ID, _) => (false, true)
+ case _ => (doBatchFetch, false)
+ }
+ }
curBlocks = createFetchRequests(curBlocks, address, isLast = true,
- collectedRemoteRequests)
+ collectedRemoteRequests, enableBatchFetch = enableBatchFetch,
+ forMergedMetas = areMergedBlocks)
curRequestSize = curBlocks.map(_.size).sum
Review comment:
nit: Unrelated to this PR, but drop this `sum` ?
##########
File path:
core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala
##########
@@ -0,0 +1,289 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import java.util.concurrent.TimeUnit
+
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+import scala.util.{Failure, Success}
+
+import org.roaringbitmap.RoaringBitmap
+
+import org.apache.spark.MapOutputTracker
+import org.apache.spark.MapOutputTracker.SHUFFLE_PUSH_MAP_ID
+import org.apache.spark.internal.Logging
+import org.apache.spark.network.buffer.ManagedBuffer
+import org.apache.spark.network.shuffle.{BlockStoreClient, MergedBlockMeta,
MergedBlocksMetaListener}
+import org.apache.spark.storage.BlockManagerId.SHUFFLE_MERGER_IDENTIFIER
+import org.apache.spark.storage.ShuffleBlockFetcherIterator._
+
+/**
+ * Helper class for [[ShuffleBlockFetcherIterator]] that encapsulates all the
push-based
+ * functionality to fetch merged block meta and merged shuffle block chunks.
+ */
+private class PushBasedFetchHelper(
+ private val iterator: ShuffleBlockFetcherIterator,
+ private val shuffleClient: BlockStoreClient,
+ private val blockManager: BlockManager,
+ private val mapOutputTracker: MapOutputTracker) extends Logging {
+
+ private[this] val startTimeNs = System.nanoTime()
+
+ private[this] val localShuffleMergerBlockMgrId = BlockManagerId(
+ SHUFFLE_MERGER_IDENTIFIER, blockManager.blockManagerId.host,
+ blockManager.blockManagerId.port, blockManager.blockManagerId.topologyInfo)
+
+ /** A map for storing merged block shuffle chunk bitmap */
+ private[this] val chunksMetaMap = new mutable.HashMap[ShuffleBlockChunkId,
RoaringBitmap]()
+
+ /**
+ * Returns true if the address is for a push-merged block.
+ */
+ def isMergedShuffleBlockAddress(address: BlockManagerId): Boolean = {
+ SHUFFLE_MERGER_IDENTIFIER.equals(address.executorId)
+ }
+
+ /**
+ * Returns true if the address is not of executor local or merged local
block. false otherwise.
+ */
+ def isNotExecutorOrMergedLocal(address: BlockManagerId): Boolean = {
+ (isMergedShuffleBlockAddress(address) && address.host !=
blockManager.blockManagerId.host) ||
+ (!isMergedShuffleBlockAddress(address) && address !=
blockManager.blockManagerId)
+ }
+
+ /**
+ * Returns true if the address is of a merged local block, false otherwise.
+ */
+ def isMergedLocal(address: BlockManagerId): Boolean = {
+ isMergedShuffleBlockAddress(address) && address.host ==
blockManager.blockManagerId.host
+ }
+
+ def getNumberOfBlocksInChunk(blockId : ShuffleBlockChunkId): Int = {
+ chunksMetaMap(blockId).getCardinality
+ }
+
+ def removeChunk(blockId: ShuffleBlockChunkId): Unit = {
+ chunksMetaMap.remove(blockId)
+ }
+
+ def createChunkBlockInfosFromMetaResponse(
+ shuffleId: Int,
+ reduceId: Int,
+ blockSize: Long,
+ numChunks: Int,
+ bitmaps: Array[RoaringBitmap]): ArrayBuffer[(BlockId, Long, Int)] = {
+ val approxChunkSize = blockSize / numChunks
+ val blocksToFetch = new ArrayBuffer[(BlockId, Long, Int)]()
+ for (i <- 0 until numChunks) {
+ val blockChunkId = ShuffleBlockChunkId(shuffleId, reduceId, i)
+ chunksMetaMap.put(blockChunkId, bitmaps(i))
+ logDebug(s"adding block chunk $blockChunkId of size $approxChunkSize")
+ blocksToFetch += ((blockChunkId, approxChunkSize, SHUFFLE_PUSH_MAP_ID))
+ }
+ blocksToFetch
+ }
+
+ def sendFetchMergedStatusRequest(req: FetchRequest): Unit = {
+ val sizeMap = req.blocks.map {
+ case FetchBlockInfo(blockId, size, _) =>
+ val shuffleBlockId = blockId.asInstanceOf[ShuffleBlockId]
+ ((shuffleBlockId.shuffleId, shuffleBlockId.reduceId), size)}.toMap
+ val address = req.address
+ val mergedBlocksMetaListener = new MergedBlocksMetaListener {
+ override def onSuccess(shuffleId: Int, reduceId: Int, meta:
MergedBlockMeta): Unit = {
+ logInfo(s"Received the meta of merged block for ($shuffleId,
$reduceId) " +
+ s"from ${req.address.host}:${req.address.port}")
+ try {
+ iterator.addToResultsQueue(MergedBlocksMetaFetchResult(shuffleId,
reduceId,
+ sizeMap((shuffleId, reduceId)), meta.getNumChunks,
meta.readChunkBitmaps(), address))
+ } catch {
+ case exception: Throwable =>
+ logError(s"Failed to parse the meta of merged block for
($shuffleId, $reduceId) " +
+ s"from ${req.address.host}:${req.address.port}", exception)
+ iterator.addToResultsQueue(
+ MergedBlocksMetaFailedFetchResult(shuffleId, reduceId, address))
+ }
+ }
+
+ override def onFailure(shuffleId: Int, reduceId: Int, exception:
Throwable): Unit = {
+ logError(s"Failed to get the meta of merged block for ($shuffleId,
$reduceId) " +
+ s"from ${req.address.host}:${req.address.port}", exception)
+
iterator.addToResultsQueue(MergedBlocksMetaFailedFetchResult(shuffleId,
reduceId, address))
+ }
+ }
+ req.blocks.foreach { block =>
+ val shuffleBlockId = block.blockId.asInstanceOf[ShuffleBlockId]
+ shuffleClient.getMergedBlockMeta(address.host, address.port,
shuffleBlockId.shuffleId,
+ shuffleBlockId.reduceId, mergedBlocksMetaListener)
+ }
+ }
+
+ // Fetch all outstanding merged local blocks
+ def fetchAllMergedLocalBlocks(
+ mergedLocalBlocks: mutable.LinkedHashSet[BlockId]): Unit = {
+ if (mergedLocalBlocks.nonEmpty) {
+ blockManager.hostLocalDirManager.foreach(fetchMergedLocalBlocks(_,
mergedLocalBlocks))
+ }
+ }
+
+ /**
+ * Fetch the merged blocks dirs if they are not in the cache and eventually
fetch merged local
+ * blocks.
+ */
+ private def fetchMergedLocalBlocks(
+ hostLocalDirManager: HostLocalDirManager,
+ mergedLocalBlocks: mutable.LinkedHashSet[BlockId]): Unit = {
+ val cachedMergerDirs = hostLocalDirManager.getCachedHostLocalDirs.get(
+ SHUFFLE_MERGER_IDENTIFIER)
+ if (cachedMergerDirs.isDefined) {
+ logDebug(s"Fetching local merged blocks with cached executors dir: " +
+ s"${cachedMergerDirs.get.mkString(", ")}")
+ mergedLocalBlocks.foreach(blockId =>
+ fetchMergedLocalBlock(blockId, cachedMergerDirs.get,
localShuffleMergerBlockMgrId))
+ } else {
+ logDebug(s"Asynchronous fetching local merged blocks without cached
executors dir")
+ hostLocalDirManager.getHostLocalDirs(localShuffleMergerBlockMgrId.host,
+ localShuffleMergerBlockMgrId.port, Array(SHUFFLE_MERGER_IDENTIFIER)) {
+ case Success(dirs) =>
+ mergedLocalBlocks.takeWhile {
+ blockId =>
+ logDebug(s"Successfully fetched local dirs: " +
+ s"${dirs.get(SHUFFLE_MERGER_IDENTIFIER).mkString(", ")}")
+ fetchMergedLocalBlock(blockId, dirs(SHUFFLE_MERGER_IDENTIFIER),
+ localShuffleMergerBlockMgrId)
+ }
+ logDebug(s"Got local merged blocks (without cached executors' dir)
in " +
+ s"${TimeUnit.NANOSECONDS.toMillis(System.nanoTime() -
startTimeNs)} ms")
+ case Failure(throwable) =>
+ // If we see an exception with getting the local dirs for local
merged blocks,
+ // we fallback to fetch the original unmerged blocks. We do not
report block fetch
+ // failure.
+ logWarning(s"Error occurred while getting the local dirs for local
merged " +
+ s"blocks: ${mergedLocalBlocks.mkString(", ")}. Fetch the original
blocks instead",
+ throwable)
+ mergedLocalBlocks.foreach(
+ blockId => iterator.addToResultsQueue(
+ IgnoreFetchResult(blockId, localShuffleMergerBlockMgrId, 0,
isNetworkReqDone = false))
+ )
+ }
+ }
+ }
+
+ /**
+ * Fetch a single local merged block generated.
+ * @param blockId ShuffleBlockId to be fetched
+ * @param localDirs Local directories where the merged shuffle files are
stored
+ * @param blockManagerId BlockManagerId
+ * @return Boolean represents successful or failed fetch
+ */
+ private[this] def fetchMergedLocalBlock(
+ blockId: BlockId,
+ localDirs: Array[String],
+ blockManagerId: BlockManagerId): Boolean = {
+ try {
+ val shuffleBlockId = blockId.asInstanceOf[ShuffleBlockId]
+ val chunksMeta = blockManager.getMergedBlockMeta(shuffleBlockId,
localDirs)
+ .readChunkBitmaps()
+ // Fetch local merged shuffle block data as multiple chunks
+ val bufs: Seq[ManagedBuffer] =
blockManager.getMergedBlockData(shuffleBlockId, localDirs)
+ // Update total number of blocks to fetch, reflecting the multiple local
chunks
+ iterator.foundMoreBlocksToFetch(bufs.size - 1)
+ for (chunkId <- bufs.indices) {
+ val buf = bufs(chunkId)
+ buf.retain()
+ val shuffleChunkId = ShuffleBlockChunkId(shuffleBlockId.shuffleId,
+ shuffleBlockId.reduceId, chunkId)
+ iterator.addToResultsQueue(
+ SuccessFetchResult(shuffleChunkId, SHUFFLE_PUSH_MAP_ID,
blockManagerId, buf.size(), buf,
+ isNetworkReqDone = false))
+ chunksMetaMap.put(shuffleChunkId, chunksMeta(chunkId))
+ }
+ true
+ } catch {
+ case e: Exception =>
+ // If we see an exception with reading a local merged block, we
fallback to
+ // fetch the original unmerged blocks. We do not report block fetch
failure
+ // and will continue with the remaining local block read.
+ logWarning(s"Error occurred while fetching local merged block, " +
+ s"prepare to fetch the original blocks", e)
+ iterator.addToResultsQueue(
+ IgnoreFetchResult(blockId, blockManagerId, 0, isNetworkReqDone =
false))
+ false
+ }
+ }
+
+ /**
+ * Initiate fetching fallback blocks for a merged block (or a merged block
chunk) that's failed
+ * to fetch.
+ * It calls out to map output tracker to get the list of original blocks for
the
+ * given merged blocks, split them into remote and local blocks, and process
them
+ * accordingly.
+ * The fallback happens when:
+ * 1. There is an exception while creating shuffle block chunk from local
merged shuffle block.
+ * See fetchLocalBlock.
+ * 2. There is a failure when fetching remote shuffle block chunks.
+ * 3. There is a failure when processing SuccessFetchResult which is for a
shuffle chunk
+ * (local or remote).
+ *
+ * @return number of blocks processed
+ */
+ def initiateFallbackBlockFetchForMergedBlock(
+ blockId: BlockId,
+ address: BlockManagerId): Int = {
Review comment:
Only `ShuffleBlockId` or `ShuffleBlockChunkId` are possible in this method,
right?
Add that as a precondition and check for `isInstanceOf[ShuffleBlockId]`
instead of `isShuffle`?
##########
File path:
core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -871,6 +1054,82 @@ final class ShuffleBlockFetcherIterator(
"Failed to get block " + blockId + ", which is not a shuffle block",
e)
}
}
+
+ /**
+ * All the below methods are used by [[PushBasedFetchHelper]] to communicate
with the iterator
+ */
+ private[storage] def addToResultsQueue(result: FetchResult): Unit = {
+ results.put(result)
+ }
+
+ private[storage] def foundMoreBlocksToFetch(moreBlocksToFetch: Int): Unit = {
+ numBlocksToFetch += moreBlocksToFetch
+ }
+
+ /**
+ * Currently used by [[PushBasedFetchHelper]] to fetch fallback blocks when
there is a fetch
+ * failure with a shuffle merged block/chunk.
+ */
+ private[storage] def fetchFallbackBlocks(
+ fallbackBlocksByAddr: Iterator[(BlockManagerId, Seq[(BlockId, Long,
Int)])]): Unit = {
+ val fallbackLocalBlocks = mutable.LinkedHashSet[(BlockId, Int)]()
+ val fallbackHostLocalBlocksByExecutor =
+ mutable.LinkedHashMap[BlockManagerId, Seq[(BlockId, Long, Int)]]()
+ val fallbackMergedLocalBlocks = mutable.LinkedHashSet[BlockId]()
+ val fallbackRemoteReqs = partitionBlocksByFetchMode(fallbackBlocksByAddr,
+ fallbackLocalBlocks, fallbackHostLocalBlocksByExecutor,
fallbackMergedLocalBlocks)
+ // Add the remote requests into our queue in a random order
+ fetchRequests ++= Utils.randomize(fallbackRemoteReqs)
+ logInfo(s"Started ${fallbackRemoteReqs.size} fallback remote requests for
merged")
+ // If there is any fall back block that's a local block, we get them here.
The original
+ // invocation to fetchLocalBlocks might have already returned by this
time, so we need
+ // to invoke it again here.
Review comment:
Can we rephrase this comment? The wording (`"The original invocation to
fetchLocalBlocks might have already returned by this time"`) makes it sound
like a timing issue and so potentially a race.
In reality, the initial `fetchLocalBlocks` call was for the initial request,
and for each failure to fetch merged blocks/chunks, we have to redo the
exercise for that set.
##########
File path:
core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -871,6 +1054,82 @@ final class ShuffleBlockFetcherIterator(
"Failed to get block " + blockId + ", which is not a shuffle block",
e)
}
}
+
+ /**
+ * All the below methods are used by [[PushBasedFetchHelper]] to communicate
with the iterator
+ */
+ private[storage] def addToResultsQueue(result: FetchResult): Unit = {
+ results.put(result)
+ }
+
+ private[storage] def foundMoreBlocksToFetch(moreBlocksToFetch: Int): Unit = {
+ numBlocksToFetch += moreBlocksToFetch
+ }
+
+ /**
+ * Currently used by [[PushBasedFetchHelper]] to fetch fallback blocks when
there is a fetch
+ * failure with a shuffle merged block/chunk.
+ */
+ private[storage] def fetchFallbackBlocks(
+ fallbackBlocksByAddr: Iterator[(BlockManagerId, Seq[(BlockId, Long,
Int)])]): Unit = {
+ val fallbackLocalBlocks = mutable.LinkedHashSet[(BlockId, Int)]()
+ val fallbackHostLocalBlocksByExecutor =
+ mutable.LinkedHashMap[BlockManagerId, Seq[(BlockId, Long, Int)]]()
+ val fallbackMergedLocalBlocks = mutable.LinkedHashSet[BlockId]()
+ val fallbackRemoteReqs = partitionBlocksByFetchMode(fallbackBlocksByAddr,
+ fallbackLocalBlocks, fallbackHostLocalBlocksByExecutor,
fallbackMergedLocalBlocks)
+ // Add the remote requests into our queue in a random order
+ fetchRequests ++= Utils.randomize(fallbackRemoteReqs)
+ logInfo(s"Started ${fallbackRemoteReqs.size} fallback remote requests for
merged")
+ // If there is any fall back block that's a local block, we get them here.
The original
+ // invocation to fetchLocalBlocks might have already returned by this
time, so we need
+ // to invoke it again here.
+ fetchLocalBlocks(fallbackLocalBlocks)
+ // Merged local blocks should be empty during fallback
+ assert(fallbackMergedLocalBlocks.isEmpty,
+ "There should be zero merged blocks during fallback")
+ // Some of the fallback local blocks could be host local blocks
+ fetchAllHostLocalBlocks(fallbackHostLocalBlocksByExecutor)
+ }
+
+ /**
+ * Removes all the pending shuffle chunks that are on the same host as the
block chunk that had
+ * a fetch failure.
+ *
+ * @return set of all the removed shuffle chunk Ids.
+ */
+ private[storage] def removePendingChunks(
+ failedBlockId: ShuffleBlockChunkId,
+ address: BlockManagerId): mutable.HashSet[ShuffleBlockChunkId] = {
+ val removedChunkIds = new mutable.HashSet[ShuffleBlockChunkId]()
+
+ def sameShuffleBlockChunk(block: BlockId): Boolean = {
+ val chunkId = block.asInstanceOf[ShuffleBlockChunkId]
+ chunkId.shuffleId == failedBlockId.shuffleId && chunkId.reduceId ==
failedBlockId.reduceId
+ }
+
+ def filterRequests(queue: mutable.Queue[FetchRequest]): Unit = {
+ val fetchRequestsToRemove = new mutable.Queue[FetchRequest]()
+ fetchRequestsToRemove ++= queue.dequeueAll(req => {
+ val firstBlock = req.blocks.head
+ firstBlock.blockId.isShuffleChunk && req.address.equals(address) &&
+ sameShuffleBlockChunk(firstBlock.blockId)
+ })
+ fetchRequestsToRemove.foreach(req => {
+ removedChunkIds ++=
req.blocks.iterator.map(_.blockId.asInstanceOf[ShuffleBlockChunkId])
+ })
+ }
+
+ filterRequests(fetchRequests)
+ val defRequests = deferredFetchRequests.remove(address).orNull
+ if (defRequests != null) {
+ filterRequests(defRequests)
+ if (defRequests.nonEmpty) {
+ deferredFetchRequests(address) = defRequests
+ }
+ }
Review comment:
nit:
```suggestion
deferredFetchRequests.get(address).foreach(defRequests => {
filterRequests(defRequests)
if (defRequests.isEmpty) deferredFetchRequests.remove(address)
})
```
##########
File path:
core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -347,20 +360,48 @@ final class ShuffleBlockFetcherIterator(
}
}
- private[this] def partitionBlocksByFetchMode(): ArrayBuffer[FetchRequest] = {
+ /**
+ * This is called from initialize and also from the fallback which is
triggered from
+ * [[PushBasedFetchHelper]].
+ */
+ private[this] def partitionBlocksByFetchMode(
+ blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])],
+ localBlocks: scala.collection.mutable.LinkedHashSet[(BlockId, Int)],
+ hostLocalBlocksByExecutor: mutable.LinkedHashMap[BlockManagerId,
Seq[(BlockId, Long, Int)]],
+ mergedLocalBlocks: mutable.LinkedHashSet[BlockId]):
ArrayBuffer[FetchRequest] = {
logDebug(s"maxBytesInFlight: $maxBytesInFlight, targetRemoteRequestSize: "
+ s"$targetRemoteRequestSize, maxBlocksInFlightPerAddress:
$maxBlocksInFlightPerAddress")
- // Partition to local, host-local and remote blocks. Remote blocks are
further split into
- // FetchRequests of size at most maxBytesInFlight in order to limit the
amount of data in flight
+ // Partition to local, host-local, merged-local, remote (includes
merged-remote) blocks.
+ // Remote blocks are further split into FetchRequests of size at most
maxBytesInFlight in order
+ // to limit the amount of data in flight
val collectedRemoteRequests = new ArrayBuffer[FetchRequest]
+ val hostLocalBlocksCurrentIteration = mutable.LinkedHashSet[(BlockId,
Int)]()
var localBlockBytes = 0L
var hostLocalBlockBytes = 0L
+ var mergedLocalBlockBytes = 0L
var remoteBlockBytes = 0L
+ val prevNumBlocksToFetch = numBlocksToFetch
val fallback = FallbackStorage.FALLBACK_BLOCK_MANAGER_ID.executorId
for ((address, blockInfos) <- blocksByAddress) {
- if (Seq(blockManager.blockManagerId.executorId,
fallback).contains(address.executorId)) {
+ if (pushBasedFetchHelper.isMergedShuffleBlockAddress(address)) {
+ // These are push-based merged blocks or chunks of these merged blocks.
+ if (address.host == blockManager.blockManagerId.host) {
+ checkBlockSizes(blockInfos)
+ val pushMergedBlockInfos = mergeContinuousShuffleBlockIdsIfNeeded(
+ blockInfos.map(info => FetchBlockInfo(info._1, info._2, info._3)),
doBatchFetch = false)
+ numBlocksToFetch += pushMergedBlockInfos.size
+ mergedLocalBlocks ++= pushMergedBlockInfos.map(info => info.blockId)
+ mergedLocalBlockBytes += pushMergedBlockInfos.map(_.size).sum
+ logInfo(s"Got ${pushMergedBlockInfos.size} local merged blocks " +
+ s"of size $mergedLocalBlockBytes")
+ } else {
+ remoteBlockBytes += blockInfos.map(_._2).sum
+ collectFetchRequests(address, blockInfos, collectedRemoteRequests)
+ }
+ } else if (
+ Seq(blockManager.blockManagerId.executorId,
fallback).contains(address.executorId)) {
Review comment:
While we are at it, make it a `Set`?
##########
File path:
core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
##########
@@ -22,31 +22,40 @@ import java.nio.ByteBuffer
import java.util.UUID
import java.util.concurrent.{CompletableFuture, Semaphore}
+import scala.collection.mutable
import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.Future
import io.netty.util.internal.OutOfDirectMemoryError
import org.mockito.ArgumentMatchers.{any, eq => meq}
-import org.mockito.Mockito.{mock, times, verify, when}
+import org.mockito.Mockito.{doThrow, mock, times, verify, when}
+import org.mockito.invocation.InvocationOnMock
import org.mockito.stubbing.Answer
+import org.roaringbitmap.RoaringBitmap
import org.scalatest.PrivateMethodTester
-import org.apache.spark.{SparkFunSuite, TaskContext}
+import org.apache.spark.{MapOutputTracker, SparkFunSuite, TaskContext}
+import org.apache.spark.MapOutputTracker.SHUFFLE_PUSH_MAP_ID
import org.apache.spark.network._
import org.apache.spark.network.buffer.{FileSegmentManagedBuffer,
ManagedBuffer}
-import org.apache.spark.network.shuffle.{BlockFetchingListener,
DownloadFileManager, ExternalBlockStoreClient}
+import org.apache.spark.network.shuffle.{BlockFetchingListener,
DownloadFileManager, ExternalBlockStoreClient, MergedBlockMeta,
MergedBlocksMetaListener}
import org.apache.spark.network.util.LimitedInputStream
import org.apache.spark.shuffle.{FetchFailedException,
ShuffleReadMetricsReporter}
-import org.apache.spark.storage.ShuffleBlockFetcherIterator.FetchBlockInfo
+import org.apache.spark.storage.BlockManagerId.SHUFFLE_MERGER_IDENTIFIER
+import org.apache.spark.storage.ShuffleBlockFetcherIterator._
import org.apache.spark.util.Utils
class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with
PrivateMethodTester {
Review comment:
Also add tests for:
a) a deserialization failure resulting in initiating the fallback.
b) a fetch failure of both the merged block and the fallback block being
reported to the driver as a fetch failure.
Are these handled already?
##########
File path:
core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -1124,4 +1394,54 @@ object ShuffleBlockFetcherIterator {
*/
private[storage]
case class DeferFetchRequestResult(fetchRequest: FetchRequest) extends
FetchResult
+
+ /**
+ * Result of an un-successful fetch of either of these:
+ * 1) Remote shuffle block chunk.
+ * 2) Local merged block data.
+ *
+ * Instead of treating this as a FailureFetchResult, we ignore this failure
+ * and fallback to fetch the original unmerged blocks.
+ * @param blockId block id
+ * @param address BlockManager that the merged block was attempted to be
fetched from
+ * @param size size of the block, used to update bytesInFlight.
+ * @param isNetworkReqDone Is this the last network request for this host in
this fetch
+ * request. Used to update reqsInFlight.
+ */
+ private[storage] case class IgnoreFetchResult(blockId: BlockId,
Review comment:
We are not ignoring the result as such; rather, we are using it to initiate
a fallback.
`IgnoreFetchResult` -> `RetriableMergeFailureResult`? Or something better?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]