Ngone51 commented on a change in pull request #30480:
URL: https://github.com/apache/spark/pull/30480#discussion_r614588672



##########
File path: core/src/main/scala/org/apache/spark/MapOutputTracker.scala
##########
@@ -181,64 +235,141 @@ private class ShuffleStatus(numPartitions: Int) extends Logging {
   def removeOutputsByFilter(f: BlockManagerId => Boolean): Unit = withWriteLock {
     for (mapIndex <- mapStatuses.indices) {
       if (mapStatuses(mapIndex) != null && f(mapStatuses(mapIndex).location)) {
-        _numAvailableOutputs -= 1
+        _numAvailableMapOutputs -= 1
         mapStatuses(mapIndex) = null
         invalidateSerializedMapOutputStatusCache()
       }
     }
+    for (reduceId <- mergeStatuses.indices) {
+      if (mergeStatuses(reduceId) != null && f(mergeStatuses(reduceId).location)) {
+        _numAvailableMergeResults -= 1
+        mergeStatuses(reduceId) = null
+        invalidateSerializedMergeOutputStatusCache()
+      }
+    }
+  }
+
+  /**
+   * Number of partitions that have shuffle map outputs.
+   */
+  def numAvailableMapOutputs: Int = withReadLock {
+    _numAvailableMapOutputs
   }
 
   /**
-   * Number of partitions that have shuffle outputs.
+   * Number of shuffle partitions that have already been merge finalized when push-based
+   * is enabled.
    */
-  def numAvailableOutputs: Int = withReadLock {
-    _numAvailableOutputs
+  def numAvailableMergeResults: Int = withReadLock {
+    _numAvailableMergeResults
   }
 
   /**
    * Returns the sequence of partition ids that are missing (i.e. needs to be computed).
    */
   def findMissingPartitions(): Seq[Int] = withReadLock {
     val missing = (0 until numPartitions).filter(id => mapStatuses(id) == null)
-    assert(missing.size == numPartitions - _numAvailableOutputs,
-      s"${missing.size} missing, expected ${numPartitions - _numAvailableOutputs}")
+    assert(missing.size == numPartitions - _numAvailableMapOutputs,
+      s"${missing.size} missing, expected ${numPartitions - _numAvailableMapOutputs}")
     missing
   }
 
   /**
-   * Serializes the mapStatuses array into an efficient compressed format. See the comments on
-   * `MapOutputTracker.serializeMapStatuses()` for more details on the serialization format.
+   * Serializes the mapStatuses or mergeStatuses array into an efficient compressed format. See
+   * the comments on `MapOutputTracker.serializeOutputStatuses()` for more details on the
+   * serialization format.
    *
    * This method is designed to be called multiple times and implements caching in order to speed
    * up subsequent requests. If the cache is empty and multiple threads concurrently attempt to
-   * serialize the map statuses then serialization will only be performed in a single thread and all
-   * other threads will block until the cache is populated.
+   * serialize the statuses array then serialization will only be performed in a single thread and
+   * all other threads will block until the cache is populated.
    */
-  def serializedMapStatus(
+  def serializedOutputStatus(
       broadcastManager: BroadcastManager,
       isLocal: Boolean,
      minBroadcastSize: Int,
-      conf: SparkConf): Array[Byte] = {
-    var result: Array[Byte] = null
+      conf: SparkConf,
+      isMapOnlyOutput: Boolean): (Array[Byte], Array[Byte]) = {

Review comment:
       I think we can rename `isMapOnlyOutput` to `needMergeOutput` and 
simplify the code below as:
   
    ```scala
    // Locals holding the serialized bytes to return; filled from the cache first,
    // then computed and cached if still missing.
    var mapStatuses: Array[Byte] = null
    var mergeStatuses: Array[Byte] = null

    withReadLock {
      if (cachedSerializedMapStatus != null) {
        mapStatuses = cachedSerializedMapStatus
      }
      if (needMergeOutput && cachedSerializedMergeStatus != null) {
        mergeStatuses = cachedSerializedMergeStatus
      }
    }

    if (mapStatuses == null) {
      mapStatuses =
        serializeAndCacheMapStatuses(broadcastManager, isLocal, minBroadcastSize, conf)
    }
    // If push based shuffle is enabled, serialize and cache both the map and merge statuses
    if (needMergeOutput && mergeStatuses == null) {
      mergeStatuses =
        serializeAndCacheMergeStatuses(broadcastManager, isLocal, minBroadcastSize, conf)
    }
    ```

##########
File path: core/src/main/scala/org/apache/spark/MapOutputTracker.scala
##########
@@ -812,61 +1115,151 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr
       startPartition: Int,
       endPartition: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
     logDebug(s"Fetching outputs for shuffle $shuffleId")
-    val statuses = getStatuses(shuffleId, conf)
+    val (mapOutputStatuses, mergedOutputStatuses) = getStatuses(shuffleId, conf)
     try {
-      val actualEndMapIndex = if (endMapIndex == Int.MaxValue) statuses.length else endMapIndex
+      val actualEndMapIndex =
+        if (endMapIndex == Int.MaxValue) mapOutputStatuses.length else endMapIndex
       logDebug(s"Convert map statuses for shuffle $shuffleId, " +
         s"mappers $startMapIndex-$actualEndMapIndex, partitions $startPartition-$endPartition")
       MapOutputTracker.convertMapStatuses(
-        shuffleId, startPartition, endPartition, statuses, startMapIndex, actualEndMapIndex)
+        shuffleId, startPartition, endPartition, mapOutputStatuses, startMapIndex,
+          actualEndMapIndex, Option(mergedOutputStatuses))
     } catch {
       case e: MetadataFetchFailedException =>
         // We experienced a fetch failure so our mapStatuses cache is outdated; clear it:
         mapStatuses.clear()
+        mergeStatuses.clear()
+        throw e
+    }
+  }
+
+  override def getMapSizesForMergeResult(
+      shuffleId: Int,
+      partitionId: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
+    logDebug(s"Fetching backup outputs for shuffle $shuffleId, partition $partitionId")
+    // Fetch the map statuses and merge statuses again since they might have already been
+    // cleared by another task running in the same executor.
+    val (mapOutputStatuses, mergeResultStatuses) = getStatuses(shuffleId, conf)
+    try {
+      val mergeStatus = mergeResultStatuses(partitionId)
+      // If the original MergeStatus is no longer available, we cannot identify the list of
+      // unmerged blocks to fetch in this case. Throw MetadataFetchFailedException in this case.
+      MapOutputTracker.validateStatus(mergeStatus, shuffleId, partitionId)
+      // Use the MergeStatus's partition level bitmap since we are doing partition level fallback
+      MapOutputTracker.getMapStatusesForMergeStatus(shuffleId, partitionId,
+        mapOutputStatuses, mergeStatus.tracker)
+    } catch {
+      // We experienced a fetch failure so our mapStatuses cache is outdated; clear it
+      case e: MetadataFetchFailedException =>
+        mapStatuses.clear()
+        mergeStatuses.clear()
+        throw e
+    }
+  }
+
+  override def getMapSizesForMergeResult(
+      shuffleId: Int,
+      partitionId: Int,
+      chunkTracker: RoaringBitmap): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
+    logDebug(s"Fetching backup outputs for shuffle $shuffleId, partition $partitionId")
+    // Fetch the map statuses and merge statuses again since they might have already been
+    // cleared by another task running in the same executor.
+    val (mapOutputStatuses, _) = getStatuses(shuffleId, conf)
+    try {
+      MapOutputTracker.getMapStatusesForMergeStatus(shuffleId, partitionId, mapOutputStatuses,
+        chunkTracker)
+    } catch {
+      // We experienced a fetch failure so our mapStatuses cache is outdated; clear it:
+      case e: MetadataFetchFailedException =>
+        mapStatuses.clear()
+        mergeStatuses.clear()
         throw e
     }
   }
 
   /**
-   * Get or fetch the array of MapStatuses for a given shuffle ID. NOTE: clients MUST synchronize
+   * Get or fetch the array of MapStatuses and MergeStatuses if push based shuffle enabled
+   * for a given shuffle ID. NOTE: clients MUST synchronize
    * on this array when reading it, because on the driver, we may be changing it in place.
    *
    * (It would be nice to remove this restriction in the future.)
    */
-  private def getStatuses(shuffleId: Int, conf: SparkConf): Array[MapStatus] = {
-    val statuses = mapStatuses.get(shuffleId).orNull
-    if (statuses == null) {
-      logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them")
-      val startTimeNs = System.nanoTime()
-      fetchingLock.withLock(shuffleId) {
-        var fetchedStatuses = mapStatuses.get(shuffleId).orNull
-        if (fetchedStatuses == null) {
-          logInfo("Doing the fetch; tracker endpoint = " + trackerEndpoint)
-          val fetchedBytes = askTracker[Array[Byte]](GetMapOutputStatuses(shuffleId))
-          try {
-            fetchedStatuses = MapOutputTracker.deserializeMapStatuses(fetchedBytes, conf)
-          } catch {
-            case e: SparkException =>
-              throw new MetadataFetchFailedException(shuffleId, -1,
-                s"Unable to deserialize broadcasted map statuses for shuffle $shuffleId: " +
-                  e.getCause)
+  private def getStatuses(
+      shuffleId: Int,
+      conf: SparkConf): (Array[MapStatus], Array[MergeStatus]) = {
+    if (fetchMergeResult) {
+      val mapOutputStatuses = mapStatuses.get(shuffleId).orNull
+      val mergeOutputStatuses = mergeStatuses.get(shuffleId).orNull
+
+      if (mapOutputStatuses == null || mergeOutputStatuses == null) {
+        logInfo("Don't have map/merge outputs for shuffle " + shuffleId + ", fetching them")
+        val startTimeNs = System.nanoTime()
+        fetchingLock.withLock(shuffleId) {
+          var fetchedMapStatuses = mapStatuses.get(shuffleId).orNull
+          var fetchedMergeStatuses = mergeStatuses.get(shuffleId).orNull
+          if (fetchedMapStatuses == null || fetchedMergeStatuses == null) {
+            logInfo("Doing the fetch; tracker endpoint = " + trackerEndpoint)
+            val fetchedBytes =
+              askTracker[(Array[Byte], Array[Byte])](GetMapAndMergeResultStatuses(shuffleId))

Review comment:
       I may have missed some discussion since my last review, but I think this breaches the decision we made before:
   
   we won't affect the existing code path in the case of map status only.
   
   I think you could return only the map statuses on the sender side to keep the same behavior?
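   
   To make that concrete, the kind of thing I have in mind on the driver side is sketched below. The dispatch here is simplified and `serializedMapAndMergeStatus` is just a placeholder name, not something this PR defines:
   
    ```scala
    // Sketch only: keep the existing reply for GetMapOutputStatuses exactly as it is
    // today, and let the new message carry the merge statuses, so the map-status-only
    // wire format never changes for jobs without push-based shuffle.
    case GetMapOutputStatuses(shuffleId) =>
      context.reply(
        shuffleStatus.serializedMapStatus(broadcastManager, isLocal, minBroadcastSize, conf))
    case GetMapAndMergeResultStatuses(shuffleId) =>
      context.reply(
        shuffleStatus.serializedMapAndMergeStatus(broadcastManager, isLocal, minBroadcastSize, conf))
    ```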

##########
File path: core/src/main/scala/org/apache/spark/MapOutputTracker.scala
##########
@@ -1000,18 +1403,55 @@ private[spark] object MapOutputTracker extends Logging {
       shuffleId: Int,
       startPartition: Int,
       endPartition: Int,
-      statuses: Array[MapStatus],
+      mapStatuses: Array[MapStatus],
       startMapIndex : Int,
-      endMapIndex: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
-    assert (statuses != null)
+      endMapIndex: Int,
+      mergeStatuses: Option[Array[MergeStatus]] = None):
+      Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
+    assert (mapStatuses != null)
     val splitsByAddress = new HashMap[BlockManagerId, ListBuffer[(BlockId, Long, Int)]]
-    val iter = statuses.iterator.zipWithIndex
-    for ((status, mapIndex) <- iter.slice(startMapIndex, endMapIndex)) {
-      if (status == null) {
-        val errorMessage = s"Missing an output location for shuffle $shuffleId"
-        logError(errorMessage)
-        throw new MetadataFetchFailedException(shuffleId, startPartition, errorMessage)
-      } else {
+    // Only use MergeStatus for reduce tasks that fetch all map outputs. Since a merged shuffle
+    // partition consists of blocks merged in random order, we are unable to serve map index
+    // subrange requests. However, when a reduce task needs to fetch blocks from a subrange of
+    // map outputs, it usually indicates skewed partitions which push-based shuffle delegates
+    // to AQE to handle.
+    // TODO: SPARK-35036: Instead of reading map blocks in case of AQE with Push based shuffle,
+    // TODO: improve push based shuffle to read partial merged blocks satisfying the start/end
+    // TODO: map indexes
+    if (mergeStatuses.isDefined && startMapIndex == 0 && endMapIndex == mapStatuses.length) {
+      // We have MergeStatus and full range of mapIds are requested so return a merged block.
+      val numMaps = mapStatuses.length
+      mergeStatuses.get.zipWithIndex.slice(startPartition, endPartition).foreach {
+        case (mergeStatus, partId) =>
+          val remainingMapStatuses = if (mergeStatus != null && mergeStatus.totalSize > 0) {
+            // If MergeStatus is available for the given partition, add location of the
+            // pre-merged shuffle partition for this partition ID. Here we create a
+            // ShuffleBlockId with mapId being SHUFFLE_PUSH_MAP_ID to indicate this is
+            // a merged shuffle block.
+            splitsByAddress.getOrElseUpdate(mergeStatus.location, ListBuffer()) +=
+              ((ShuffleBlockId(shuffleId, SHUFFLE_PUSH_MAP_ID, partId), mergeStatus.totalSize, -1))
+            // For the "holes" in this pre-merged shuffle partition, i.e., unmerged mapper
+            // shuffle partition blocks, fetch the original map produced shuffle partition blocks
+            mergeStatus.getMissingMaps(numMaps).map(mapStatuses.zipWithIndex)

Review comment:
       `mapStatuses.zipWithIndex` would be called multiple times here?
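   
   Something like this is what I mean, just as a sketch (`mapStatusesWithIndex` is my own name):
   
    ```scala
    // Zip the map statuses with their indices once, before the loop over reduce
    // partitions, instead of rebuilding the zipped array for every partition.
    val mapStatusesWithIndex = mapStatuses.zipWithIndex
    mergeStatuses.get.zipWithIndex.slice(startPartition, endPartition).foreach {
      case (mergeStatus, partId) =>
        // ... unchanged handling of the merged block itself ...
        mergeStatus.getMissingMaps(numMaps).map(mapStatusesWithIndex)
    }
    ```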

##########
File path: core/src/main/scala/org/apache/spark/MapOutputTracker.scala
##########
@@ -49,7 +50,10 @@ import org.apache.spark.util._
  *
  * All public methods of this class are thread-safe.
  */
-private class ShuffleStatus(numPartitions: Int) extends Logging {
+private class ShuffleStatus(
+    numPartitions: Int,
+    numReducers: Int,

Review comment:
       I'm thinking we could use `numReducers = -1` to indicate that push-based shuffle is disabled. That way we wouldn't need a separate `isPushBasedShuffleEnabled` flag. It might be a little bit tricky, though, so it's up to you.
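   
   Roughly what I have in mind, as a sketch only (the default value and the derived members are my own naming, not from this PR):
   
    ```scala
    private class ShuffleStatus(
        numPartitions: Int,
        numReducers: Int = -1) extends Logging {
    
      // Derive the flag from the sentinel instead of passing a separate boolean:
      // numReducers == -1 means push-based shuffle is disabled for this shuffle.
      private def pushBasedShuffleEnabled: Boolean = numReducers > 0
    
      // Only allocate the merge-status array when push-based shuffle is in play.
      val mergeStatuses: Array[MergeStatus] =
        if (pushBasedShuffleEnabled) new Array[MergeStatus](numReducers) else Array.empty[MergeStatus]
    }
    ```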

##########
File path: core/src/main/scala/org/apache/spark/MapOutputTracker.scala
##########
@@ -633,23 +887,50 @@ private[spark] class MapOutputTrackerMaster(
 
   /**
    * Return the preferred hosts on which to run the given map output partition in a given shuffle,
-   * i.e. the nodes that the most outputs for that partition are on.
+   * i.e. the nodes that the most outputs for that partition are on. If the map output is
+   * pre-merged, then return the node where the merged block is located if the merge ratio is
+   * above the threshold.
    *
    * @param dep shuffle dependency object
    * @param partitionId map output partition that we want to read
    * @return a sequence of host names
    */
   def getPreferredLocationsForShuffle(dep: ShuffleDependency[_, _, _], partitionId: Int)
       : Seq[String] = {
-    if (shuffleLocalityEnabled && dep.rdd.partitions.length < SHUFFLE_PREF_MAP_THRESHOLD &&
-        dep.partitioner.numPartitions < SHUFFLE_PREF_REDUCE_THRESHOLD) {
-      val blockManagerIds = getLocationsWithLargestOutputs(dep.shuffleId, partitionId,
-        dep.partitioner.numPartitions, REDUCER_PREF_LOCS_FRACTION)
-      if (blockManagerIds.nonEmpty) {
-        blockManagerIds.get.map(_.host)
+    val shuffleStatus = shuffleStatuses.get(dep.shuffleId).orNull
+    if (shuffleStatus != null) {
+      // Check if the map output is pre-merged and if the merge ratio is above the threshold.
+      // If so, the location of the merged block is the preferred location.
+      val preferredLoc = if (pushBasedShuffleEnabled) {

Review comment:
       Doesn't this path need to respect `shuffleLocalityEnabled` too?
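   
   i.e., adding the existing flag to the new branch's guard; the body below is just the PR's code with the condition changed:
   
    ```scala
    // Mirror the existing non-merged path: only consult the merged-block location
    // when shuffle locality preferences are enabled at all.
    val preferredLoc = if (shuffleLocalityEnabled && pushBasedShuffleEnabled) {
      shuffleStatus.withMergeStatuses { statuses =>
        val status = statuses(partitionId)
        val numMaps = dep.rdd.partitions.length
        if (status != null && status.getNumMissingMapOutputs(numMaps).toDouble / numMaps
            <= (1 - REDUCER_PREF_LOCS_FRACTION)) {
          Seq(status.location.host)
        } else {
          Nil
        }
      }
    } else {
      Nil
    }
    ```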

##########
File path: core/src/main/scala/org/apache/spark/MapOutputTracker.scala
##########
@@ -1000,18 +1403,55 @@ private[spark] object MapOutputTracker extends Logging {
       shuffleId: Int,
       startPartition: Int,
       endPartition: Int,
-      statuses: Array[MapStatus],
+      mapStatuses: Array[MapStatus],
       startMapIndex : Int,
-      endMapIndex: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
-    assert (statuses != null)
+      endMapIndex: Int,
+      mergeStatuses: Option[Array[MergeStatus]] = None):
+      Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
+    assert (mapStatuses != null)
     val splitsByAddress = new HashMap[BlockManagerId, ListBuffer[(BlockId, Long, Int)]]
-    val iter = statuses.iterator.zipWithIndex
-    for ((status, mapIndex) <- iter.slice(startMapIndex, endMapIndex)) {
-      if (status == null) {
-        val errorMessage = s"Missing an output location for shuffle $shuffleId"
-        logError(errorMessage)
-        throw new MetadataFetchFailedException(shuffleId, startPartition, errorMessage)
-      } else {
+    // Only use MergeStatus for reduce tasks that fetch all map outputs. Since a merged shuffle
+    // partition consists of blocks merged in random order, we are unable to serve map index
+    // subrange requests. However, when a reduce task needs to fetch blocks from a subrange of
+    // map outputs, it usually indicates skewed partitions which push-based shuffle delegates
+    // to AQE to handle.
+    // TODO: SPARK-35036: Instead of reading map blocks in case of AQE with Push based shuffle,
+    // TODO: improve push based shuffle to read partial merged blocks satisfying the start/end
+    // TODO: map indexes
+    if (mergeStatuses.isDefined && startMapIndex == 0 && endMapIndex == mapStatuses.length) {

Review comment:
       nit: `mergeStatuses.exists(_.nonEmpty)`?
   
   We can also skip this branch when the merge statuses array is empty.
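   
   i.e., as a sketch (the rest of the condition stays as is):
   
    ```scala
    // Also bail out of the merged-block path when the MergeStatus array is empty.
    if (mergeStatuses.exists(_.nonEmpty) && startMapIndex == 0 && endMapIndex == mapStatuses.length) {
    ```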

##########
File path: core/src/main/scala/org/apache/spark/MapOutputTracker.scala
##########
@@ -812,61 +1115,151 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr
       startPartition: Int,
       endPartition: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
     logDebug(s"Fetching outputs for shuffle $shuffleId")
-    val statuses = getStatuses(shuffleId, conf)
+    val (mapOutputStatuses, mergedOutputStatuses) = getStatuses(shuffleId, conf)
     try {
-      val actualEndMapIndex = if (endMapIndex == Int.MaxValue) statuses.length else endMapIndex
+      val actualEndMapIndex =
+        if (endMapIndex == Int.MaxValue) mapOutputStatuses.length else endMapIndex
       logDebug(s"Convert map statuses for shuffle $shuffleId, " +
         s"mappers $startMapIndex-$actualEndMapIndex, partitions $startPartition-$endPartition")
       MapOutputTracker.convertMapStatuses(
-        shuffleId, startPartition, endPartition, statuses, startMapIndex, actualEndMapIndex)
+        shuffleId, startPartition, endPartition, mapOutputStatuses, startMapIndex,
+          actualEndMapIndex, Option(mergedOutputStatuses))
     } catch {
       case e: MetadataFetchFailedException =>
         // We experienced a fetch failure so our mapStatuses cache is outdated; clear it:
         mapStatuses.clear()
+        mergeStatuses.clear()
+        throw e
+    }
+  }
+
+  override def getMapSizesForMergeResult(

Review comment:
       comment "test only"?

##########
File path: core/src/main/scala/org/apache/spark/MapOutputTracker.scala
##########
@@ -1000,18 +1403,55 @@ private[spark] object MapOutputTracker extends Logging {
       shuffleId: Int,
       startPartition: Int,
       endPartition: Int,
-      statuses: Array[MapStatus],
+      mapStatuses: Array[MapStatus],
       startMapIndex : Int,
-      endMapIndex: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
-    assert (statuses != null)
+      endMapIndex: Int,
+      mergeStatuses: Option[Array[MergeStatus]] = None):
+      Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
+    assert (mapStatuses != null)
     val splitsByAddress = new HashMap[BlockManagerId, ListBuffer[(BlockId, Long, Int)]]
-    val iter = statuses.iterator.zipWithIndex
-    for ((status, mapIndex) <- iter.slice(startMapIndex, endMapIndex)) {
-      if (status == null) {
-        val errorMessage = s"Missing an output location for shuffle $shuffleId"
-        logError(errorMessage)
-        throw new MetadataFetchFailedException(shuffleId, startPartition, errorMessage)
-      } else {
+    // Only use MergeStatus for reduce tasks that fetch all map outputs. Since a merged shuffle
+    // partition consists of blocks merged in random order, we are unable to serve map index
+    // subrange requests. However, when a reduce task needs to fetch blocks from a subrange of
+    // map outputs, it usually indicates skewed partitions which push-based shuffle delegates
+    // to AQE to handle.
+    // TODO: SPARK-35036: Instead of reading map blocks in case of AQE with Push based shuffle,
+    // TODO: improve push based shuffle to read partial merged blocks satisfying the start/end
+    // TODO: map indexes
+    if (mergeStatuses.isDefined && startMapIndex == 0 && endMapIndex == mapStatuses.length) {
+      // We have MergeStatus and full range of mapIds are requested so return a merged block.
+      val numMaps = mapStatuses.length
+      mergeStatuses.get.zipWithIndex.slice(startPartition, endPartition).foreach {
+        case (mergeStatus, partId) =>
+          val remainingMapStatuses = if (mergeStatus != null && mergeStatus.totalSize > 0) {

Review comment:
       I remember Magnet claims it can fall back to the original fetch (using the MapStatus) when a fetch failure happens. But here it looks like we only collect the merged status for those maps, without the backup map statuses. (In my mind, we could collect both the merge statuses and the original map statuses together, so that we can fall back when needed.) How do we plan to support the fallback?
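   
   To make the question concrete, the alternative I'm describing would look roughly like the sketch below (purely illustrative; `backupMapBlocksByPartition` is my own name, not something in this PR):
   
    ```scala
    // While recording the merged block for a reduce partition, also stash the original
    // map-produced blocks it covers, so a fetch failure on the merged block could fall
    // back to them without another metadata round trip.
    val backupMapBlocksByPartition =
      new HashMap[Int, ListBuffer[(BlockManagerId, (BlockId, Long, Int))]]
    ```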

##########
File path: core/src/main/scala/org/apache/spark/MapOutputTracker.scala
##########
@@ -812,61 +1115,151 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr
       startPartition: Int,
       endPartition: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
     logDebug(s"Fetching outputs for shuffle $shuffleId")
-    val statuses = getStatuses(shuffleId, conf)
+    val (mapOutputStatuses, mergedOutputStatuses) = getStatuses(shuffleId, conf)
     try {
-      val actualEndMapIndex = if (endMapIndex == Int.MaxValue) statuses.length else endMapIndex
+      val actualEndMapIndex =
+        if (endMapIndex == Int.MaxValue) mapOutputStatuses.length else endMapIndex
       logDebug(s"Convert map statuses for shuffle $shuffleId, " +
         s"mappers $startMapIndex-$actualEndMapIndex, partitions $startPartition-$endPartition")
       MapOutputTracker.convertMapStatuses(
-        shuffleId, startPartition, endPartition, statuses, startMapIndex, actualEndMapIndex)
+        shuffleId, startPartition, endPartition, mapOutputStatuses, startMapIndex,
+          actualEndMapIndex, Option(mergedOutputStatuses))
     } catch {
       case e: MetadataFetchFailedException =>
         // We experienced a fetch failure so our mapStatuses cache is outdated; clear it:
         mapStatuses.clear()
+        mergeStatuses.clear()
+        throw e
+    }
+  }
+
+  override def getMapSizesForMergeResult(
+      shuffleId: Int,
+      partitionId: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
+    logDebug(s"Fetching backup outputs for shuffle $shuffleId, partition $partitionId")
+    // Fetch the map statuses and merge statuses again since they might have already been
+    // cleared by another task running in the same executor.
+    val (mapOutputStatuses, mergeResultStatuses) = getStatuses(shuffleId, conf)
+    try {
+      val mergeStatus = mergeResultStatuses(partitionId)
+      // If the original MergeStatus is no longer available, we cannot identify the list of
+      // unmerged blocks to fetch in this case. Throw MetadataFetchFailedException in this case.
+      MapOutputTracker.validateStatus(mergeStatus, shuffleId, partitionId)
+      // Use the MergeStatus's partition level bitmap since we are doing partition level fallback
+      MapOutputTracker.getMapStatusesForMergeStatus(shuffleId, partitionId,
+        mapOutputStatuses, mergeStatus.tracker)
+    } catch {
+      // We experienced a fetch failure so our mapStatuses cache is outdated; clear it
+      case e: MetadataFetchFailedException =>
+        mapStatuses.clear()
+        mergeStatuses.clear()
+        throw e
+    }
+  }
+
+  override def getMapSizesForMergeResult(

Review comment:
       comment "test only"?
   
   

##########
File path: core/src/main/scala/org/apache/spark/MapOutputTracker.scala
##########
@@ -633,23 +887,50 @@ private[spark] class MapOutputTrackerMaster(
 
   /**
    * Return the preferred hosts on which to run the given map output partition in a given shuffle,
-   * i.e. the nodes that the most outputs for that partition are on.
+   * i.e. the nodes that the most outputs for that partition are on. If the map output is
+   * pre-merged, then return the node where the merged block is located if the merge ratio is
+   * above the threshold.
    *
    * @param dep shuffle dependency object
    * @param partitionId map output partition that we want to read
    * @return a sequence of host names
    */
   def getPreferredLocationsForShuffle(dep: ShuffleDependency[_, _, _], partitionId: Int)
       : Seq[String] = {
-    if (shuffleLocalityEnabled && dep.rdd.partitions.length < SHUFFLE_PREF_MAP_THRESHOLD &&
-        dep.partitioner.numPartitions < SHUFFLE_PREF_REDUCE_THRESHOLD) {
-      val blockManagerIds = getLocationsWithLargestOutputs(dep.shuffleId, partitionId,
-        dep.partitioner.numPartitions, REDUCER_PREF_LOCS_FRACTION)
-      if (blockManagerIds.nonEmpty) {
-        blockManagerIds.get.map(_.host)
+    val shuffleStatus = shuffleStatuses.get(dep.shuffleId).orNull
+    if (shuffleStatus != null) {
+      // Check if the map output is pre-merged and if the merge ratio is above the threshold.
+      // If so, the location of the merged block is the preferred location.
+      val preferredLoc = if (pushBasedShuffleEnabled) {
+        shuffleStatus.withMergeStatuses { statuses =>
+          val status = statuses(partitionId)
+          val numMaps = dep.rdd.partitions.length
+          if (status != null && status.getNumMissingMapOutputs(numMaps).toDouble / numMaps
+            <= (1 - REDUCER_PREF_LOCS_FRACTION)) {
+            Seq(status.location.host)
+          } else {
+            Nil
+          }
+        }
       } else {
         Nil
       }
+      if (!preferredLoc.isEmpty) {

Review comment:
       nit: `preferredLoc.nonEmpty`



