mridulm commented on code in PR #42296:
URL: https://github.com/apache/spark/pull/42296#discussion_r1325310110


##########
core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala:
##########
@@ -355,8 +367,28 @@ final class ShuffleBlockFetcherIterator(
                 updateMergedReqsDuration(wasReqForMergedChunks = true)
                 results.put(FallbackOnPushMergedFailureResult(
                   block, address, infoMap(blockId)._1, 
remainingBlocks.isEmpty))
-              } else {
+              } else if (!shouldPerformShuffleLocationRefresh) {
                 results.put(FailureFetchResult(block, infoMap(blockId)._2, 
address, e))
+              } else {
+                val (shuffleId, mapId) = BlockId.getShuffleIdAndMapId(block)

Review Comment:
   Should we move the `getShuffleIdAndMapId` call into the `Try`?
   We will effectively block the shuffle indefinitely if 
`getShuffleIdAndMapId` throws an exception (it should not currently — but the 
code could evolve).
   
   Something like:
   ```
                   Try {
                     val (shuffleId, mapId) = 
BlockId.getShuffleIdAndMapId(block)
                     mapOutputTrackerWorker
                       .getMapOutputLocationWithRefresh(shuffleId, mapId, 
address)
                   } match {
   ```



##########
core/src/main/scala/org/apache/spark/TestUtils.scala:
##########
@@ -491,6 +491,28 @@ private[spark] object TestUtils {
       EnumSet.of(OWNER_READ, OWNER_EXECUTE, OWNER_WRITE))
     file.getPath
   }
+
+  /** Sets all configs specified in `confPairs`, calls `f`, and then restores 
them. */
+  def withConf[T](confPairs: (String, String)*)(f: => T): T = {
+    val conf = SparkEnv.get.conf
+    val (keys, values) = confPairs.unzip
+    val currentValues = keys.map { key =>
+      if (conf.contains(key)) {
+        Some(conf.get(key))
+      } else {
+        None
+      }
+    }
+    (keys, values).zipped.foreach { (key, value) =>
+      conf.set(key, value)
+    }
+    try f finally {
+      keys.zip(currentValues).foreach {
+        case (key, Some(value)) => conf.set(key, value)
+        case (key, None) => conf.remove(key)
+      }
+    }
+  }

Review Comment:
   nit:
   ```suggestion
     def withConf[T](confPairs: (String, String)*)(f: => T): T = {
       val conf = SparkEnv.get.conf
       val inputConfMap = confPairs.toMap
       val modifiedValues = conf.getAll.filter(kv => 
inputConfMap.contains(kv._1)).toMap
       inputConfMap.foreach { kv =>
         conf.set(kv._1, kv._2)
       }
       try f finally {
         inputConfMap.keys.foreach { key =>
           if (modifiedValues.contains(key)) {
             conf.set(key, modifiedValues(key))
           } else {
             conf.remove(key)
           }
         }
       }
     }
   ```



##########
core/src/main/scala/org/apache/spark/MapOutputTracker.scala:
##########
@@ -1288,6 +1288,32 @@ private[spark] class MapOutputTrackerWorker(conf: 
SparkConf) extends MapOutputTr
     mapSizesByExecutorId.iter
   }
 
+  def getMapOutputLocationWithRefresh(
+      shuffleId: Int,
+      mapId: Long,
+      prevLocation: BlockManagerId): BlockManagerId = {
+    // Try to get the cached location first in case other concurrent tasks
+    // fetched the fresh location already
+    var currentLocationOpt = getMapOutputLocation(shuffleId, mapId)
+    if (currentLocationOpt.contains(prevLocation)) {
+      // Address in the cache unchanged. Try to clean cache and get a fresh 
location
+      unregisterShuffle(shuffleId)
+      currentLocationOpt = getMapOutputLocation(shuffleId, mapId, 
canFetchMergeResult = true)

Review Comment:
   ```suggestion
         currentLocationOpt = getMapOutputLocation(shuffleId, mapId, 
fetchMergeResult)
   ```



##########
core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala:
##########
@@ -664,6 +667,107 @@ class ShuffleBlockFetcherIteratorSuite extends 
SparkFunSuite with PrivateMethodT
     verify(blocks(ShuffleBlockId(0, 2, 0)), times(0)).release()
   }
 
+  test("handle map output location change") {
+    TestUtils.withConf(
+      "spark.decommission.enabled" -> "true",
+      "spark.storage.decommission.enabled" -> "true",
+      "spark.storage.decommission.shuffleBlocks.enabled" -> "true"
+    ) {
+      val blockManager = createMockBlockManager()
+      val localBmId = blockManager.blockManagerId
+      val remoteBmId = BlockManagerId("test-remote-1", "test-remote-1", 2)
+      val blocks = Map[BlockId, ManagedBuffer](
+        ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer(),
+        ShuffleBlockId(0, 1, 0) -> createMockManagedBuffer(),
+        ShuffleBlockId(0, 2, 0) -> createMockManagedBuffer()
+      )
+
+      doReturn(blocks(ShuffleBlockId(0, 2, 0))).when(blockManager)
+        .getLocalBlockData(meq(ShuffleBlockId(0, 2, 0)))
+
+      answerFetchBlocks { invocation =>
+        val host = invocation.getArgument[String](0)
+        val listener = invocation.getArgument[BlockFetchingListener](4)
+        host match {
+          case "test-remote-1" =>
+            listener.onBlockFetchSuccess(
+              ShuffleBlockId(0, 0, 0).toString, blocks(ShuffleBlockId(0, 0, 
0)))
+            // TODO: update exception type here
+            listener.onBlockFetchFailure(
+              ShuffleBlockId(0, 1, 0).toString, new RuntimeException())
+            listener.onBlockFetchFailure(
+              ShuffleBlockId(0, 2, 0).toString, new RuntimeException())
+          case "test-remote-2" =>
+            listener.onBlockFetchSuccess(
+              ShuffleBlockId(0, 1, 0).toString, blocks(ShuffleBlockId(0, 1, 
0)))
+          case "test-local-host" =>
+            listener.onBlockFetchSuccess(
+              ShuffleBlockId(0, 2, 0).toString, blocks(ShuffleBlockId(0, 2, 
0)))
+        }
+      }
+
+      when(mapOutputTracker.getMapOutputLocationWithRefresh(any(), any(), 
any()))
+        .thenAnswer { invocation =>
+          val mapId = invocation.getArgument[Long](1)
+          mapId match {
+            case 0 => BlockManagerId("test-remote-1", "test-remote-1", 2)
+            case 1 => BlockManagerId("test-remote-2", "test-remote-2", 2)
+            case 2 => localBmId
+          }
+        }
+
+      Seq(true, false).foreach { isEnabled =>
+        SparkEnv.get.conf.set(
+          "spark.storage.decommission.shuffleBlocks.refreshLocationsEnabled", 
isEnabled.toString)
+        val iterator = createShuffleBlockIteratorWithDefaults(
+          Map(remoteBmId -> toBlockList(blocks.keys, 1L, 0)),
+          blockManager = Some(blockManager)
+        )
+        if (isEnabled) {
+          assert(iterator.map(_._1).toSet == blocks.keys)
+          val bytesInFlightAccess = 
PrivateMethod[Long](Symbol("bytesInFlight"))
+          assert(iterator.invokePrivate(bytesInFlightAccess()) == 0)
+        } else {
+          intercept[FetchFailedException] {
+            iterator.toList
+          }
+        }
+      }
+    }
+  }
+
+
+  test("metadata fetch failure in handle map output location change") {
+    TestUtils.withConf(
+      "spark.decommission.enabled" -> "true",
+      "spark.storage.decommission.enabled" -> "true",
+      "spark.storage.decommission.shuffleBlocks.enabled" -> "true",
+      "spark.storage.decommission.shuffleBlocks.refreshLocationsEnabled" -> 
"true"
+    ) {
+      val remoteBmId = BlockManagerId("test-remote-1", "test-remote-1", 2)
+      val blocks = Map[BlockId, ManagedBuffer](
+        ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer()
+      )
+      answerFetchBlocks { invocation =>
+        val host = invocation.getArgument[String](0)
+        val listener = invocation.getArgument[BlockFetchingListener](4)
+        host match {
+          case "test-remote-1" =>
+            listener.onBlockFetchFailure(
+              ShuffleBlockId(0, 0, 0).toString, new RuntimeException())
+        }
+      }
+      when(mapOutputTracker.getMapOutputLocationWithRefresh(any(), any(), 
any()))
+        .thenAnswer(_ => throw new MetadataFetchFailedException(0, 0, ""))

Review Comment:
   super nit: set `mapId` to `-1`



##########
core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala:
##########
@@ -664,6 +667,107 @@ class ShuffleBlockFetcherIteratorSuite extends 
SparkFunSuite with PrivateMethodT
     verify(blocks(ShuffleBlockId(0, 2, 0)), times(0)).release()
   }
 
+  test("handle map output location change") {
+    TestUtils.withConf(
+      "spark.decommission.enabled" -> "true",
+      "spark.storage.decommission.enabled" -> "true",
+      "spark.storage.decommission.shuffleBlocks.enabled" -> "true"
+    ) {
+      val blockManager = createMockBlockManager()
+      val localBmId = blockManager.blockManagerId
+      val remoteBmId = BlockManagerId("test-remote-1", "test-remote-1", 2)
+      val blocks = Map[BlockId, ManagedBuffer](
+        ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer(),
+        ShuffleBlockId(0, 1, 0) -> createMockManagedBuffer(),
+        ShuffleBlockId(0, 2, 0) -> createMockManagedBuffer()
+      )
+
+      doReturn(blocks(ShuffleBlockId(0, 2, 0))).when(blockManager)
+        .getLocalBlockData(meq(ShuffleBlockId(0, 2, 0)))
+
+      answerFetchBlocks { invocation =>
+        val host = invocation.getArgument[String](0)
+        val listener = invocation.getArgument[BlockFetchingListener](4)
+        host match {
+          case "test-remote-1" =>
+            listener.onBlockFetchSuccess(
+              ShuffleBlockId(0, 0, 0).toString, blocks(ShuffleBlockId(0, 0, 
0)))
+            // TODO: update exception type here
+            listener.onBlockFetchFailure(
+              ShuffleBlockId(0, 1, 0).toString, new RuntimeException())
+            listener.onBlockFetchFailure(
+              ShuffleBlockId(0, 2, 0).toString, new RuntimeException())
+          case "test-remote-2" =>
+            listener.onBlockFetchSuccess(
+              ShuffleBlockId(0, 1, 0).toString, blocks(ShuffleBlockId(0, 1, 
0)))
+          case "test-local-host" =>
+            listener.onBlockFetchSuccess(
+              ShuffleBlockId(0, 2, 0).toString, blocks(ShuffleBlockId(0, 2, 
0)))
+        }
+      }
+
+      when(mapOutputTracker.getMapOutputLocationWithRefresh(any(), any(), 
any()))
+        .thenAnswer { invocation =>
+          val mapId = invocation.getArgument[Long](1)
+          mapId match {
+            case 0 => BlockManagerId("test-remote-1", "test-remote-1", 2)
+            case 1 => BlockManagerId("test-remote-2", "test-remote-2", 2)
+            case 2 => localBmId
+          }
+        }
+
+      Seq(true, false).foreach { isEnabled =>
+        SparkEnv.get.conf.set(
+          "spark.storage.decommission.shuffleBlocks.refreshLocationsEnabled", 
isEnabled.toString)
+        val iterator = createShuffleBlockIteratorWithDefaults(
+          Map(remoteBmId -> toBlockList(blocks.keys, 1L, 0)),
+          blockManager = Some(blockManager)
+        )
+        if (isEnabled) {
+          assert(iterator.map(_._1).toSet == blocks.keys)
+          val bytesInFlightAccess = 
PrivateMethod[Long](Symbol("bytesInFlight"))
+          assert(iterator.invokePrivate(bytesInFlightAccess()) == 0)
+        } else {
+          intercept[FetchFailedException] {
+            iterator.toList
+          }
+        }
+      }
+    }
+  }
+
+
+  test("metadata fetch failure in handle map output location change") {

Review Comment:
   Can you also add a simple test to ensure the existing behavior is preserved 
when no migration has happened?



##########
core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala:
##########
@@ -264,18 +272,22 @@ final class ShuffleBlockFetcherIterator(
       case FetchBlockInfo(blockId, size, mapIndex) => (blockId.toString, 
(size, mapIndex))
     }.toMap
     val remainingBlocks = new HashSet[String]() ++= infoMap.keys
-    val deferredBlocks = new ArrayBuffer[String]()
+    val deferredBlocks = new HashMap[BlockManagerId, Queue[String]]()
     val blockIds = req.blocks.map(_.blockId.toString)
     val address = req.address
     val requestStartTime = clock.nanoTime()
 
     @inline def enqueueDeferredFetchRequestIfNecessary(): Unit = {
       if (remainingBlocks.isEmpty && deferredBlocks.nonEmpty) {
-        val blocks = deferredBlocks.map { blockId =>
-          val (size, mapIndex) = infoMap(blockId)
-          FetchBlockInfo(BlockId(blockId), size, mapIndex)
+        val newAddressToBlocks = new HashMap[BlockManagerId, 
Queue[FetchBlockInfo]]()
+        deferredBlocks.foreach { case (blockManagerId, blockIds) =>
+          val blocks = blockIds.map { blockId =>
+            val (size, mapIndex) = infoMap(blockId)
+            FetchBlockInfo(BlockId(blockId), size, mapIndex)
+          }
+          newAddressToBlocks.put(blockManagerId, blocks)
         }

Review Comment:
   Scratch this - required for decrementing accounting details.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to