This is an automated email from the ASF dual-hosted git repository.

wuyi pushed a commit to branch SPARK-48394-3.5
in repository https://gitbox.apache.org/repos/asf/spark.git

commit c9d94ef8e7c7d35e3f2995ffb63596a993a766c8
Author: Yi Wu <[email protected]>
AuthorDate: Fri May 24 16:01:17 2024 -0700

    [SPARK-48394][CORE] Cleanup mapIdToMapIndex on mapoutput unregister
    
    This PR cleans up `mapIdToMapIndex` when the corresponding mapstatus is 
unregistered in three places:
    * `removeMapOutput`
    * `removeOutputsByFilter`
    * `addMapOutput` (old mapstatus overwritten)
    
    There is only one valid mapstatus for the same `mapIndex` at the same time 
in Spark. `mapIdToMapIndex` should also follow the same rule to avoid chaos.
    
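    Does this PR introduce any user-facing change? No.
    
    How was this patch tested? Unit tests.
    
    Was this patch authored or co-authored using generative AI tooling? No.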
    
    Closes #46706 from Ngone51/SPARK-43043-followup.
    
    Lead-authored-by: Yi Wu <[email protected]>
    Co-authored-by: wuyi <[email protected]>
    Signed-off-by: Dongjoon Hyun <[email protected]>
---
 .../scala/org/apache/spark/MapOutputTracker.scala  | 26 ++++++----
 .../org/apache/spark/MapOutputTrackerSuite.scala   | 55 ++++++++++++++++++++++
 2 files changed, 72 insertions(+), 9 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala 
b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
index 3495536a3508..9a7a3b0c0e75 100644
--- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -42,7 +42,6 @@ import org.apache.spark.scheduler.{MapStatus, MergeStatus, 
ShuffleOutputStatus}
 import org.apache.spark.shuffle.MetadataFetchFailedException
 import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockId, 
ShuffleMergedBlockId}
 import org.apache.spark.util._
-import org.apache.spark.util.collection.OpenHashMap
 import org.apache.spark.util.io.{ChunkedByteBuffer, 
ChunkedByteBufferOutputStream}
 
 /**
@@ -151,17 +150,22 @@ private class ShuffleStatus(
   /**
    * Mapping from a mapId to the mapIndex, this is required to reduce the 
searching overhead within
    * the function updateMapOutput(mapId, bmAddress).
+   *
+   * Exposed for testing.
    */
-  private[this] val mapIdToMapIndex = new OpenHashMap[Long, Int]()
+  private[spark] val mapIdToMapIndex = new HashMap[Long, Int]()
 
   /**
    * Register a map output. If there is already a registered location for the 
map output then it
    * will be replaced by the new location.
    */
   def addMapOutput(mapIndex: Int, status: MapStatus): Unit = withWriteLock {
-    if (mapStatuses(mapIndex) == null) {
+    val currentMapStatus = mapStatuses(mapIndex)
+    if (currentMapStatus == null) {
       _numAvailableMapOutputs += 1
       invalidateSerializedMapOutputStatusCache()
+    } else {
+      mapIdToMapIndex.remove(currentMapStatus.mapId)
     }
     mapStatuses(mapIndex) = status
     mapIdToMapIndex(status.mapId) = mapIndex
@@ -190,8 +194,8 @@ private class ShuffleStatus(
           mapStatus.updateLocation(bmAddress)
           invalidateSerializedMapOutputStatusCache()
         case None =>
-          if (mapIndex.map(mapStatusesDeleted).exists(_.mapId == mapId)) {
-            val index = mapIndex.get
+          val index = mapStatusesDeleted.indexWhere(x => x != null && x.mapId 
== mapId)
+          if (index >= 0 && mapStatuses(index) == null) {
             val mapStatus = mapStatusesDeleted(index)
             mapStatus.updateLocation(bmAddress)
             mapStatuses(index) = mapStatus
@@ -216,9 +220,11 @@ private class ShuffleStatus(
    */
   def removeMapOutput(mapIndex: Int, bmAddress: BlockManagerId): Unit = 
withWriteLock {
     logDebug(s"Removing existing map output ${mapIndex} ${bmAddress}")
-    if (mapStatuses(mapIndex) != null && mapStatuses(mapIndex).location == 
bmAddress) {
+    val currentMapStatus = mapStatuses(mapIndex)
+    if (currentMapStatus != null && currentMapStatus.location == bmAddress) {
       _numAvailableMapOutputs -= 1
-      mapStatusesDeleted(mapIndex) = mapStatuses(mapIndex)
+      mapIdToMapIndex.remove(currentMapStatus.mapId)
+      mapStatusesDeleted(mapIndex) = currentMapStatus
       mapStatuses(mapIndex) = null
       invalidateSerializedMapOutputStatusCache()
     }
@@ -284,9 +290,11 @@ private class ShuffleStatus(
    */
   def removeOutputsByFilter(f: BlockManagerId => Boolean): Unit = 
withWriteLock {
     for (mapIndex <- mapStatuses.indices) {
-      if (mapStatuses(mapIndex) != null && f(mapStatuses(mapIndex).location)) {
+      val currentMapStatus = mapStatuses(mapIndex)
+      if (currentMapStatus != null && f(currentMapStatus.location)) {
         _numAvailableMapOutputs -= 1
-        mapStatusesDeleted(mapIndex) = mapStatuses(mapIndex)
+        mapIdToMapIndex.remove(currentMapStatus.mapId)
+        mapStatusesDeleted(mapIndex) = currentMapStatus
         mapStatuses(mapIndex) = null
         invalidateSerializedMapOutputStatusCache()
       }
diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala 
b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
index 450ff01921a8..d6f925ddced9 100644
--- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
@@ -1109,4 +1109,59 @@ class MapOutputTrackerSuite extends SparkFunSuite with 
LocalSparkContext {
       rpcEnv.shutdown()
     }
   }
+
+  test(
+    "SPARK-48394: mapIdToMapIndex should cleanup unused mapIndexes after 
removeOutputsByFilter"
+  ) {
+    val rpcEnv = createRpcEnv("test")
+    val tracker = newTrackerMaster()
+    try {
+      tracker.trackerEndpoint = 
rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME,
+        new MapOutputTrackerMasterEndpoint(rpcEnv, tracker, conf))
+      tracker.registerShuffle(0, 1, 1)
+      tracker.registerMapOutput(0, 0, MapStatus(BlockManagerId("exec-1", 
"hostA", 1000),
+        Array(2L), 0))
+      tracker.removeOutputsOnHost("hostA")
+      assert(tracker.shuffleStatuses(0).mapIdToMapIndex.filter(_._2 == 0).size 
== 0)
+    } finally {
+      tracker.stop()
+      rpcEnv.shutdown()
+    }
+  }
+
+  test("SPARK-48394: mapIdToMapIndex should cleanup unused mapIndexes after 
unregisterMapOutput") {
+    val rpcEnv = createRpcEnv("test")
+    val tracker = newTrackerMaster()
+    try {
+      tracker.trackerEndpoint = 
rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME,
+        new MapOutputTrackerMasterEndpoint(rpcEnv, tracker, conf))
+      tracker.registerShuffle(0, 1, 1)
+      tracker.registerMapOutput(0, 0, MapStatus(BlockManagerId("exec-1", 
"hostA", 1000),
+        Array(2L), 0))
+      tracker.unregisterMapOutput(0, 0, BlockManagerId("exec-1", "hostA", 
1000))
+      assert(tracker.shuffleStatuses(0).mapIdToMapIndex.filter(_._2 == 0).size 
== 0)
+    } finally {
+      tracker.stop()
+      rpcEnv.shutdown()
+    }
+  }
+
+  test("SPARK-48394: mapIdToMapIndex should cleanup unused mapIndexes after 
registerMapOutput") {
+    val rpcEnv = createRpcEnv("test")
+    val tracker = newTrackerMaster()
+    try {
+      tracker.trackerEndpoint = 
rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME,
+        new MapOutputTrackerMasterEndpoint(rpcEnv, tracker, conf))
+      tracker.registerShuffle(0, 1, 1)
+      tracker.registerMapOutput(0, 0, MapStatus(BlockManagerId("exec-1", 
"hostA", 1000),
+        Array(2L), 0))
+      // Another task also finished working on partition 0.
+      tracker.registerMapOutput(0, 0, MapStatus(BlockManagerId("exec-2", 
"hostB", 1000),
+        Array(2L), 1))
+      assert(tracker.shuffleStatuses(0).mapIdToMapIndex.filter(_._2 == 0).size 
== 1)
+    } finally {
+      tracker.stop()
+      rpcEnv.shutdown()
+    }
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to