mridulm commented on code in PR #42296:
URL: https://github.com/apache/spark/pull/42296#discussion_r1325310110
##########
core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala:
##########
@@ -355,8 +367,28 @@ final class ShuffleBlockFetcherIterator(
updateMergedReqsDuration(wasReqForMergedChunks = true)
results.put(FallbackOnPushMergedFailureResult(
block, address, infoMap(blockId)._1,
remainingBlocks.isEmpty))
- } else {
+ } else if (!shouldPerformShuffleLocationRefresh) {
results.put(FailureFetchResult(block, infoMap(blockId)._2,
address, e))
+ } else {
+ val (shuffleId, mapId) = BlockId.getShuffleIdAndMapId(block)
Review Comment:
Should we move the `getShuffleIdAndMapId` call into the `Try`?
We will effectively block the shuffle indefinitely if
`getShuffleIdAndMapId` throws an exception (it should not currently — but the
code could evolve).
Something like:
```
Try {
val (shuffleId, mapId) =
BlockId.getShuffleIdAndMapId(block)
mapOutputTrackerWorker
.getMapOutputLocationWithRefresh(shuffleId, mapId,
address)
} match {
```
##########
core/src/main/scala/org/apache/spark/TestUtils.scala:
##########
@@ -491,6 +491,28 @@ private[spark] object TestUtils {
EnumSet.of(OWNER_READ, OWNER_EXECUTE, OWNER_WRITE))
file.getPath
}
+
+ /** Sets all configs specified in `confPairs`, calls `f`, and then restores
them. */
+ def withConf[T](confPairs: (String, String)*)(f: => T): T = {
+ val conf = SparkEnv.get.conf
+ val (keys, values) = confPairs.unzip
+ val currentValues = keys.map { key =>
+ if (conf.contains(key)) {
+ Some(conf.get(key))
+ } else {
+ None
+ }
+ }
+ (keys, values).zipped.foreach { (key, value) =>
+ conf.set(key, value)
+ }
+ try f finally {
+ keys.zip(currentValues).foreach {
+ case (key, Some(value)) => conf.set(key, value)
+ case (key, None) => conf.remove(key)
+ }
+ }
+ }
Review Comment:
nit:
```suggestion
def withConf[T](confPairs: (String, String)*)(f: => T): T = {
val conf = SparkEnv.get.conf
val inputConfMap = confPairs.toMap
val modifiedValues = conf.getAll.filter(kv =>
inputConfMap.contains(kv._1)).toMap
inputConfMap.foreach { kv =>
conf.set(kv._1, kv._2)
}
try f finally {
inputConfMap.keys.foreach { key =>
if (modifiedValues.contains(key)) {
conf.set(key, modifiedValues(key))
} else {
conf.remove(key)
}
}
}
}
```
##########
core/src/main/scala/org/apache/spark/MapOutputTracker.scala:
##########
@@ -1288,6 +1288,32 @@ private[spark] class MapOutputTrackerWorker(conf:
SparkConf) extends MapOutputTr
mapSizesByExecutorId.iter
}
+ def getMapOutputLocationWithRefresh(
+ shuffleId: Int,
+ mapId: Long,
+ prevLocation: BlockManagerId): BlockManagerId = {
+ // Try to get the cached location first in case other concurrent tasks
+ // fetched the fresh location already
+ var currentLocationOpt = getMapOutputLocation(shuffleId, mapId)
+ if (currentLocationOpt.contains(prevLocation)) {
+ // Address in the cache unchanged. Try to clean cache and get a fresh
location
+ unregisterShuffle(shuffleId)
+ currentLocationOpt = getMapOutputLocation(shuffleId, mapId,
canFetchMergeResult = true)
Review Comment:
```suggestion
currentLocationOpt = getMapOutputLocation(shuffleId, mapId,
fetchMergeResult)
```
##########
core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala:
##########
@@ -664,6 +667,107 @@ class ShuffleBlockFetcherIteratorSuite extends
SparkFunSuite with PrivateMethodT
verify(blocks(ShuffleBlockId(0, 2, 0)), times(0)).release()
}
+ test("handle map output location change") {
+ TestUtils.withConf(
+ "spark.decommission.enabled" -> "true",
+ "spark.storage.decommission.enabled" -> "true",
+ "spark.storage.decommission.shuffleBlocks.enabled" -> "true"
+ ) {
+ val blockManager = createMockBlockManager()
+ val localBmId = blockManager.blockManagerId
+ val remoteBmId = BlockManagerId("test-remote-1", "test-remote-1", 2)
+ val blocks = Map[BlockId, ManagedBuffer](
+ ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer(),
+ ShuffleBlockId(0, 1, 0) -> createMockManagedBuffer(),
+ ShuffleBlockId(0, 2, 0) -> createMockManagedBuffer()
+ )
+
+ doReturn(blocks(ShuffleBlockId(0, 2, 0))).when(blockManager)
+ .getLocalBlockData(meq(ShuffleBlockId(0, 2, 0)))
+
+ answerFetchBlocks { invocation =>
+ val host = invocation.getArgument[String](0)
+ val listener = invocation.getArgument[BlockFetchingListener](4)
+ host match {
+ case "test-remote-1" =>
+ listener.onBlockFetchSuccess(
+ ShuffleBlockId(0, 0, 0).toString, blocks(ShuffleBlockId(0, 0,
0)))
+ // TODO: update exception type here
+ listener.onBlockFetchFailure(
+ ShuffleBlockId(0, 1, 0).toString, new RuntimeException())
+ listener.onBlockFetchFailure(
+ ShuffleBlockId(0, 2, 0).toString, new RuntimeException())
+ case "test-remote-2" =>
+ listener.onBlockFetchSuccess(
+ ShuffleBlockId(0, 1, 0).toString, blocks(ShuffleBlockId(0, 1,
0)))
+ case "test-local-host" =>
+ listener.onBlockFetchSuccess(
+ ShuffleBlockId(0, 2, 0).toString, blocks(ShuffleBlockId(0, 2,
0)))
+ }
+ }
+
+ when(mapOutputTracker.getMapOutputLocationWithRefresh(any(), any(),
any()))
+ .thenAnswer { invocation =>
+ val mapId = invocation.getArgument[Long](1)
+ mapId match {
+ case 0 => BlockManagerId("test-remote-1", "test-remote-1", 2)
+ case 1 => BlockManagerId("test-remote-2", "test-remote-2", 2)
+ case 2 => localBmId
+ }
+ }
+
+ Seq(true, false).foreach { isEnabled =>
+ SparkEnv.get.conf.set(
+ "spark.storage.decommission.shuffleBlocks.refreshLocationsEnabled",
isEnabled.toString)
+ val iterator = createShuffleBlockIteratorWithDefaults(
+ Map(remoteBmId -> toBlockList(blocks.keys, 1L, 0)),
+ blockManager = Some(blockManager)
+ )
+ if (isEnabled) {
+ assert(iterator.map(_._1).toSet == blocks.keys)
+ val bytesInFlightAccess =
PrivateMethod[Long](Symbol("bytesInFlight"))
+ assert(iterator.invokePrivate(bytesInFlightAccess()) == 0)
+ } else {
+ intercept[FetchFailedException] {
+ iterator.toList
+ }
+ }
+ }
+ }
+ }
+
+
+ test("metadata fetch failure in handle map output location change") {
+ TestUtils.withConf(
+ "spark.decommission.enabled" -> "true",
+ "spark.storage.decommission.enabled" -> "true",
+ "spark.storage.decommission.shuffleBlocks.enabled" -> "true",
+ "spark.storage.decommission.shuffleBlocks.refreshLocationsEnabled" ->
"true"
+ ) {
+ val remoteBmId = BlockManagerId("test-remote-1", "test-remote-1", 2)
+ val blocks = Map[BlockId, ManagedBuffer](
+ ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer()
+ )
+ answerFetchBlocks { invocation =>
+ val host = invocation.getArgument[String](0)
+ val listener = invocation.getArgument[BlockFetchingListener](4)
+ host match {
+ case "test-remote-1" =>
+ listener.onBlockFetchFailure(
+ ShuffleBlockId(0, 0, 0).toString, new RuntimeException())
+ }
+ }
+ when(mapOutputTracker.getMapOutputLocationWithRefresh(any(), any(),
any()))
+ .thenAnswer(_ => throw new MetadataFetchFailedException(0, 0, ""))
Review Comment:
super nit: set `mapId` to `-1`
##########
core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala:
##########
@@ -664,6 +667,107 @@ class ShuffleBlockFetcherIteratorSuite extends
SparkFunSuite with PrivateMethodT
verify(blocks(ShuffleBlockId(0, 2, 0)), times(0)).release()
}
+ test("handle map output location change") {
+ TestUtils.withConf(
+ "spark.decommission.enabled" -> "true",
+ "spark.storage.decommission.enabled" -> "true",
+ "spark.storage.decommission.shuffleBlocks.enabled" -> "true"
+ ) {
+ val blockManager = createMockBlockManager()
+ val localBmId = blockManager.blockManagerId
+ val remoteBmId = BlockManagerId("test-remote-1", "test-remote-1", 2)
+ val blocks = Map[BlockId, ManagedBuffer](
+ ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer(),
+ ShuffleBlockId(0, 1, 0) -> createMockManagedBuffer(),
+ ShuffleBlockId(0, 2, 0) -> createMockManagedBuffer()
+ )
+
+ doReturn(blocks(ShuffleBlockId(0, 2, 0))).when(blockManager)
+ .getLocalBlockData(meq(ShuffleBlockId(0, 2, 0)))
+
+ answerFetchBlocks { invocation =>
+ val host = invocation.getArgument[String](0)
+ val listener = invocation.getArgument[BlockFetchingListener](4)
+ host match {
+ case "test-remote-1" =>
+ listener.onBlockFetchSuccess(
+ ShuffleBlockId(0, 0, 0).toString, blocks(ShuffleBlockId(0, 0,
0)))
+ // TODO: update exception type here
+ listener.onBlockFetchFailure(
+ ShuffleBlockId(0, 1, 0).toString, new RuntimeException())
+ listener.onBlockFetchFailure(
+ ShuffleBlockId(0, 2, 0).toString, new RuntimeException())
+ case "test-remote-2" =>
+ listener.onBlockFetchSuccess(
+ ShuffleBlockId(0, 1, 0).toString, blocks(ShuffleBlockId(0, 1,
0)))
+ case "test-local-host" =>
+ listener.onBlockFetchSuccess(
+ ShuffleBlockId(0, 2, 0).toString, blocks(ShuffleBlockId(0, 2,
0)))
+ }
+ }
+
+ when(mapOutputTracker.getMapOutputLocationWithRefresh(any(), any(),
any()))
+ .thenAnswer { invocation =>
+ val mapId = invocation.getArgument[Long](1)
+ mapId match {
+ case 0 => BlockManagerId("test-remote-1", "test-remote-1", 2)
+ case 1 => BlockManagerId("test-remote-2", "test-remote-2", 2)
+ case 2 => localBmId
+ }
+ }
+
+ Seq(true, false).foreach { isEnabled =>
+ SparkEnv.get.conf.set(
+ "spark.storage.decommission.shuffleBlocks.refreshLocationsEnabled",
isEnabled.toString)
+ val iterator = createShuffleBlockIteratorWithDefaults(
+ Map(remoteBmId -> toBlockList(blocks.keys, 1L, 0)),
+ blockManager = Some(blockManager)
+ )
+ if (isEnabled) {
+ assert(iterator.map(_._1).toSet == blocks.keys)
+ val bytesInFlightAccess =
PrivateMethod[Long](Symbol("bytesInFlight"))
+ assert(iterator.invokePrivate(bytesInFlightAccess()) == 0)
+ } else {
+ intercept[FetchFailedException] {
+ iterator.toList
+ }
+ }
+ }
+ }
+ }
+
+
+ test("metadata fetch failure in handle map output location change") {
Review Comment:
Can you also add a simple test to ensure the existing behavior is preserved when
no migration has happened?
##########
core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala:
##########
@@ -264,18 +272,22 @@ final class ShuffleBlockFetcherIterator(
case FetchBlockInfo(blockId, size, mapIndex) => (blockId.toString,
(size, mapIndex))
}.toMap
val remainingBlocks = new HashSet[String]() ++= infoMap.keys
- val deferredBlocks = new ArrayBuffer[String]()
+ val deferredBlocks = new HashMap[BlockManagerId, Queue[String]]()
val blockIds = req.blocks.map(_.blockId.toString)
val address = req.address
val requestStartTime = clock.nanoTime()
@inline def enqueueDeferredFetchRequestIfNecessary(): Unit = {
if (remainingBlocks.isEmpty && deferredBlocks.nonEmpty) {
- val blocks = deferredBlocks.map { blockId =>
- val (size, mapIndex) = infoMap(blockId)
- FetchBlockInfo(BlockId(blockId), size, mapIndex)
+ val newAddressToBlocks = new HashMap[BlockManagerId,
Queue[FetchBlockInfo]]()
+ deferredBlocks.foreach { case (blockManagerId, blockIds) =>
+ val blocks = blockIds.map { blockId =>
+ val (size, mapIndex) = infoMap(blockId)
+ FetchBlockInfo(BlockId(blockId), size, mapIndex)
+ }
+ newAddressToBlocks.put(blockManagerId, blocks)
}
Review Comment:
Scratch this - required for decrementing accounting details.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]