mridulm commented on a change in pull request #30480:
URL: https://github.com/apache/spark/pull/30480#discussion_r532974428



##########
File path: core/src/main/scala/org/apache/spark/MapOutputTracker.scala
##########
@@ -86,36 +87,58 @@ private class ShuffleStatus(numPartitions: Int) extends 
Logging {
   // Exposed for testing
   val mapStatuses = new Array[MapStatus](numPartitions)
 
+  /**
+   * MergeStatus for each shuffle partition when push-based shuffle is 
enabled. The index of the
+   * array is the shuffle partition id (reduce id). Each value in the array is 
the MergeStatus for
+   * a shuffle partition, or null if not available. When push-based shuffle is 
enabled, this array
+   * provides a reducer oriented view of the shuffle status specifically for 
the results of
+   * merging shuffle partition blocks into per-partition merged shuffle files.
+   */
+  val mergeStatuses = new Array[MergeStatus](numReducers)

Review comment:
       Do this only if push based shuffle is enabled.

##########
File path: core/src/main/scala/org/apache/spark/MapOutputTracker.scala
##########
@@ -367,6 +480,28 @@ private[spark] abstract class MapOutputTracker(conf: 
SparkConf) extends Logging
       startPartition: Int,
       endPartition: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])]
 
+  /**
+   * Called from executors upon fetch failure on an entire merged shuffle 
partition. Such failures
+   * can happen if the shuffle client fails to fetch the metadata for the 
given merged shuffle
+   * partition. This method is to get the server URIs and output sizes for 
each shuffle block that
+   * is merged in the specified merged shuffle block so fetch failure on a 
merged shuffle block can
+   * fall back to fetching the unmerged blocks.
+   */

Review comment:
       Add `@return` with details for both methods.

##########
File path: core/src/main/scala/org/apache/spark/MapOutputTracker.scala
##########
@@ -149,12 +172,36 @@ private class ShuffleStatus(numPartitions: Int) extends 
Logging {
   def removeMapOutput(mapIndex: Int, bmAddress: BlockManagerId): Unit = 
withWriteLock {
     logDebug(s"Removing existing map output ${mapIndex} ${bmAddress}")
     if (mapStatuses(mapIndex) != null && mapStatuses(mapIndex).location == 
bmAddress) {
-      _numAvailableOutputs -= 1
+      _numAvailableMapOutputs -= 1
       mapStatuses(mapIndex) = null
       invalidateSerializedMapOutputStatusCache()
     }
   }
 
+  /**
+   * Register a merge result.
+   */
+  def addMergeResult(reduceId: Int, status: MergeStatus): Unit = withWriteLock 
{
+    if (mergeStatuses(reduceId) == null) {
+      _numAvailableMergeResults += 1
+      invalidateSerializedMergeOutputStatusCache()
+    }
+    mergeStatuses(reduceId) = status

Review comment:
       `invalidateSerializedMergeOutputStatusCache` should be done if 
`mergeStatuses(reduceId) != status`.
   

##########
File path: core/src/main/scala/org/apache/spark/MapOutputTracker.scala
##########
@@ -633,23 +810,50 @@ private[spark] class MapOutputTrackerMaster(
 
   /**
    * Return the preferred hosts on which to run the given map output partition 
in a given shuffle,
-   * i.e. the nodes that the most outputs for that partition are on.
+   * i.e. the nodes that the most outputs for that partition are on. If the 
map output is
+   * pre-merged, then return the node where the merged block is located if the 
merge ratio is
+   * above the threshold.
    *
    * @param dep shuffle dependency object
    * @param partitionId map output partition that we want to read
    * @return a sequence of host names
    */
   def getPreferredLocationsForShuffle(dep: ShuffleDependency[_, _, _], 
partitionId: Int)
       : Seq[String] = {
-    if (shuffleLocalityEnabled && dep.rdd.partitions.length < 
SHUFFLE_PREF_MAP_THRESHOLD &&
-        dep.partitioner.numPartitions < SHUFFLE_PREF_REDUCE_THRESHOLD) {
-      val blockManagerIds = getLocationsWithLargestOutputs(dep.shuffleId, 
partitionId,
-        dep.partitioner.numPartitions, REDUCER_PREF_LOCS_FRACTION)
-      if (blockManagerIds.nonEmpty) {
-        blockManagerIds.get.map(_.host)
+    val shuffleStatus = shuffleStatuses.get(dep.shuffleId).orNull
+    if (shuffleStatus != null) {
+      // Check if the map output is pre-merged and if the merge ratio is above 
the threshold.
+      // If so, the location of the merged block is the preferred location.
+      val preferredLoc = if (pushBasedShuffleEnabled) {
+        shuffleStatus.withMergeStatuses { statuses =>
+          val status = statuses(partitionId)
+          val numMaps = dep.rdd.partitions.length
+          if (status != null && status.getMissingMaps(numMaps).length.toDouble 
/ numMaps
+            <= (1 - REDUCER_PREF_LOCS_FRACTION)) {
+            Seq(status.location.host)
+          } else {
+            Nil
+          }
+        }
       } else {
         Nil
       }
+      if (!preferredLoc.isEmpty) {
+        preferredLoc
+      } else {
+        if (shuffleLocalityEnabled && dep.rdd.partitions.length < 
SHUFFLE_PREF_MAP_THRESHOLD &&
+          dep.partitioner.numPartitions < SHUFFLE_PREF_REDUCE_THRESHOLD) {
+          val blockManagerIds = getLocationsWithLargestOutputs(dep.shuffleId, 
partitionId,
+            dep.partitioner.numPartitions, REDUCER_PREF_LOCS_FRACTION)
+          if (blockManagerIds.nonEmpty) {
+            blockManagerIds.get.map(_.host)
+          } else {
+            Nil
+          }
+        } else {
+          Nil
+        }
+      }

Review comment:
       Note: With data skew, it is possible that merged output is smaller in 
size than what is computed from `shuffleLocalityEnabled` case - particularly 
given these mappers could be running on the same host.
   Practically, this is unlikely.

##########
File path: core/src/main/scala/org/apache/spark/MapOutputTracker.scala
##########
@@ -812,17 +1038,64 @@ private[spark] class MapOutputTrackerWorker(conf: 
SparkConf) extends MapOutputTr
       startPartition: Int,
       endPartition: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, 
Int)])] = {
     logDebug(s"Fetching outputs for shuffle $shuffleId")
-    val statuses = getStatuses(shuffleId, conf)
+    val (mapOutputStatuses, mergedOutputStatuses) = getStatuses(shuffleId, 
conf)
     try {
-      val actualEndMapIndex = if (endMapIndex == Int.MaxValue) statuses.length 
else endMapIndex
+      val actualEndMapIndex =
+        if (endMapIndex == Int.MaxValue) mapOutputStatuses.length else 
endMapIndex
       logDebug(s"Convert map statuses for shuffle $shuffleId, " +
         s"mappers $startMapIndex-$actualEndMapIndex, partitions 
$startPartition-$endPartition")
       MapOutputTracker.convertMapStatuses(
-        shuffleId, startPartition, endPartition, statuses, startMapIndex, 
actualEndMapIndex)
+        shuffleId, startPartition, endPartition, mapOutputStatuses, 
startMapIndex,
+          actualEndMapIndex, Option(mergedOutputStatuses))
     } catch {
       case e: MetadataFetchFailedException =>
         // We experienced a fetch failure so our mapStatuses cache is 
outdated; clear it:
         mapStatuses.clear()
+        mergeStatuses.clear()
+        throw e
+    }
+  }
+
+  override def getMapSizesForMergeResult(
+      shuffleId: Int,
+      partitionId: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] 
= {
+    logDebug(s"Fetching backup outputs for shuffle $shuffleId, partition 
$partitionId")
+    // Fetch the map statuses and merge statuses again since they might have 
already been
+    // cleared by another task running in the same executor.
+    val (mapOutputStatuses, mergeResultStatuses) = getStatuses(shuffleId, conf)
+    try {
+      val mergeStatus = mergeResultStatuses(partitionId)
+      // If the original MergeStatus is no longer available, we cannot 
identify the list of
+      // unmerged blocks to fetch in this case. Throw 
MetadataFetchFailedException in this case.
+      MapOutputTracker.validateStatus(mergeStatus, shuffleId, partitionId)
+      // User the MergeStatus's partition level bitmap since we are doing 
partition level fallback
+      MapOutputTracker.getMapStatusesForMergeStatus(shuffleId, partitionId,
+        mapOutputStatuses, mergeStatus.tracker)
+    } catch {
+      // We experienced a fetch failure so our mapStatuses cache is outdated; 
clear it:

Review comment:
       nit: remove suffix `:` in comment

##########
File path: core/src/main/scala/org/apache/spark/MapOutputTracker.scala
##########
@@ -833,33 +1106,44 @@ private[spark] class MapOutputTrackerWorker(conf: 
SparkConf) extends MapOutputTr
    *
    * (It would be nice to remove this restriction in the future.)
    */
-  private def getStatuses(shuffleId: Int, conf: SparkConf): Array[MapStatus] = 
{
-    val statuses = mapStatuses.get(shuffleId).orNull
-    if (statuses == null) {
-      logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching 
them")
+  private def getStatuses(
+      shuffleId: Int, conf: SparkConf): (Array[MapStatus], Array[MergeStatus]) 
= {
+    val mapOutputStatuses = mapStatuses.get(shuffleId).orNull
+    val mergeResultStatuses = mergeStatuses.get(shuffleId).orNull
+    if (mapOutputStatuses == null || (fetchMergeResult && mergeResultStatuses 
== null)) {
+      logInfo("Don't have map/merge outputs for shuffle " + shuffleId + ", 
fetching them")
       val startTimeNs = System.nanoTime()
       fetchingLock.withLock(shuffleId) {
-        var fetchedStatuses = mapStatuses.get(shuffleId).orNull
-        if (fetchedStatuses == null) {
-          logInfo("Doing the fetch; tracker endpoint = " + trackerEndpoint)
+        var fetchedMapStatuses = mapStatuses.get(shuffleId).orNull
+        if (fetchedMapStatuses == null) {
+          logInfo("Doing the map fetch; tracker endpoint = " + trackerEndpoint)
           val fetchedBytes = 
askTracker[Array[Byte]](GetMapOutputStatuses(shuffleId))
-          fetchedStatuses = 
MapOutputTracker.deserializeMapStatuses(fetchedBytes, conf)
-          logInfo("Got the output locations")
-          mapStatuses.put(shuffleId, fetchedStatuses)
+          fetchedMapStatuses = 
MapOutputTracker.deserializeOutputStatuses(fetchedBytes, conf)
+          logInfo("Got the map output locations")
+          mapStatuses.put(shuffleId, fetchedMapStatuses)
         }
-        logDebug(s"Fetching map output statuses for shuffle $shuffleId took " +
+        var fetchedMergeStatues = mergeStatuses.get(shuffleId).orNull
+        if (fetchMergeResult && fetchedMergeStatues == null) {
+          logInfo("Doing the merge fetch; tracker endpoint = " + 
trackerEndpoint)
+          val fetchedBytes = 
askTracker[Array[Byte]](GetMergeResultStatuses(shuffleId))
+          fetchedMergeStatues = 
MapOutputTracker.deserializeOutputStatuses(fetchedBytes, conf)
+          logInfo("Got the merge output locations")
+          mergeStatuses.put(shuffleId, fetchedMergeStatues)
+        }
+        logDebug(s"Fetching map/merge output statuses for shuffle $shuffleId 
took " +
           s"${TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTimeNs)} 
ms")
-        fetchedStatuses
+        (fetchedMapStatuses, fetchedMergeStatues)
       }

Review comment:
       If `fetchMergeResult == true`, is it right that there is an expectation 
that `(mapOutputStatuses == null) == (mergeResultStatuses == null)` ?
   
   If yes, can we simplify this ?
   a) Make this method simpler by using that condition.
b) Do we have any use case for `GetMergeResultStatuses` without also 
fetching `GetMapOutputStatuses` immediately before? If not, combine both to 
avoid two RPCs when `fetchMergeResult == true`?

##########
File path: core/src/main/scala/org/apache/spark/MapOutputTracker.scala
##########
@@ -475,15 +620,15 @@ private[spark] class MapOutputTrackerMaster(
   }
 
   /** A poison endpoint that indicates MessageLoop should exit its message 
loop. */
-  private val PoisonPill = new GetMapOutputMessage(-99, null)
+  private val PoisonPill = new GetOutputStatusesMessage(-99, true, null)
 
   // Used only in unit tests.
   private[spark] def getNumCachedSerializedBroadcast: Int = {
     shuffleStatuses.valuesIterator.count(_.hasCachedSerializedBroadcast)
   }
 
-  def registerShuffle(shuffleId: Int, numMaps: Int): Unit = {
-    if (shuffleStatuses.put(shuffleId, new ShuffleStatus(numMaps)).isDefined) {
+  def registerShuffle(shuffleId: Int, numMaps: Int, numReduces: Int = 0): Unit 
= {

Review comment:
       Remove the default value and require callers to specify reducers.
   0 as a default value does not make sense for a shuffle.
   

##########
File path: core/src/main/scala/org/apache/spark/MapOutputTracker.scala
##########
@@ -633,23 +810,50 @@ private[spark] class MapOutputTrackerMaster(
 
   /**
    * Return the preferred hosts on which to run the given map output partition 
in a given shuffle,
-   * i.e. the nodes that the most outputs for that partition are on.
+   * i.e. the nodes that the most outputs for that partition are on. If the 
map output is
+   * pre-merged, then return the node where the merged block is located if the 
merge ratio is
+   * above the threshold.
    *
    * @param dep shuffle dependency object
    * @param partitionId map output partition that we want to read
    * @return a sequence of host names
    */
   def getPreferredLocationsForShuffle(dep: ShuffleDependency[_, _, _], 
partitionId: Int)
       : Seq[String] = {
-    if (shuffleLocalityEnabled && dep.rdd.partitions.length < 
SHUFFLE_PREF_MAP_THRESHOLD &&
-        dep.partitioner.numPartitions < SHUFFLE_PREF_REDUCE_THRESHOLD) {
-      val blockManagerIds = getLocationsWithLargestOutputs(dep.shuffleId, 
partitionId,
-        dep.partitioner.numPartitions, REDUCER_PREF_LOCS_FRACTION)
-      if (blockManagerIds.nonEmpty) {
-        blockManagerIds.get.map(_.host)
+    val shuffleStatus = shuffleStatuses.get(dep.shuffleId).orNull
+    if (shuffleStatus != null) {
+      // Check if the map output is pre-merged and if the merge ratio is above 
the threshold.
+      // If so, the location of the merged block is the preferred location.
+      val preferredLoc = if (pushBasedShuffleEnabled) {
+        shuffleStatus.withMergeStatuses { statuses =>
+          val status = statuses(partitionId)
+          val numMaps = dep.rdd.partitions.length
+          if (status != null && status.getMissingMaps(numMaps).length.toDouble 
/ numMaps

Review comment:
       nit: Add a `getNumMissingMapOutputs`.

##########
File path: core/src/main/scala/org/apache/spark/MapOutputTracker.scala
##########
@@ -449,21 +586,29 @@ private[spark] class MapOutputTrackerMaster(
       try {
         while (true) {
           try {
-            val data = mapOutputRequests.take()
-             if (data == PoisonPill) {
+            val data = outputStatusesRequests.take()
+            if (data == PoisonPill) {
               // Put PoisonPill back so that other MessageLoops can see it.
-              mapOutputRequests.offer(PoisonPill)
+              outputStatusesRequests.offer(PoisonPill)
               return
             }
             val context = data.context
             val shuffleId = data.shuffleId
             val hostPort = context.senderAddress.hostPort
-            logDebug("Handling request to send map output locations for 
shuffle " + shuffleId +
-              " to " + hostPort)
             val shuffleStatus = shuffleStatuses.get(shuffleId).head
-            context.reply(
-              shuffleStatus.serializedMapStatus(broadcastManager, isLocal, 
minSizeForBroadcast,
-                conf))
+            if (data.fetchMapOutput) {
+              logDebug("Handling request to send map output locations for 
shuffle " + shuffleId +
+                " to " + hostPort)
+              context.reply(
+                shuffleStatus.serializedOutputStatus(broadcastManager, 
isLocal, minSizeForBroadcast,
+                  conf, isMapOutput = true))
+            } else {
+              logDebug("Handling request to send merge output locations for 
shuffle " + shuffleId +
+                " to " + hostPort)
+              context.reply(
+                shuffleStatus.serializedOutputStatus(broadcastManager, 
isLocal, minSizeForBroadcast,
+                  conf, isMapOutput = false))

Review comment:
       nit: unify `context.reply` with `isMapOutput = data.fetchMapOutput`

##########
File path: core/src/main/scala/org/apache/spark/MapOutputTracker.scala
##########
@@ -633,23 +810,50 @@ private[spark] class MapOutputTrackerMaster(
 
   /**
    * Return the preferred hosts on which to run the given map output partition 
in a given shuffle,
-   * i.e. the nodes that the most outputs for that partition are on.
+   * i.e. the nodes that the most outputs for that partition are on. If the 
map output is
+   * pre-merged, then return the node where the merged block is located if the 
merge ratio is
+   * above the threshold.
    *
    * @param dep shuffle dependency object
    * @param partitionId map output partition that we want to read
    * @return a sequence of host names
    */
   def getPreferredLocationsForShuffle(dep: ShuffleDependency[_, _, _], 
partitionId: Int)
       : Seq[String] = {
-    if (shuffleLocalityEnabled && dep.rdd.partitions.length < 
SHUFFLE_PREF_MAP_THRESHOLD &&
-        dep.partitioner.numPartitions < SHUFFLE_PREF_REDUCE_THRESHOLD) {
-      val blockManagerIds = getLocationsWithLargestOutputs(dep.shuffleId, 
partitionId,
-        dep.partitioner.numPartitions, REDUCER_PREF_LOCS_FRACTION)
-      if (blockManagerIds.nonEmpty) {
-        blockManagerIds.get.map(_.host)
+    val shuffleStatus = shuffleStatuses.get(dep.shuffleId).orNull
+    if (shuffleStatus != null) {
+      // Check if the map output is pre-merged and if the merge ratio is above 
the threshold.
+      // If so, the location of the merged block is the preferred location.
+      val preferredLoc = if (pushBasedShuffleEnabled) {
+        shuffleStatus.withMergeStatuses { statuses =>
+          val status = statuses(partitionId)
+          val numMaps = dep.rdd.partitions.length
+          if (status != null && status.getMissingMaps(numMaps).length.toDouble 
/ numMaps
+            <= (1 - REDUCER_PREF_LOCS_FRACTION)) {
+            Seq(status.location.host)

Review comment:
       Some of the missing maps can be colocated on the same node - if blocks 
were not pushed due to map output for reducer being large. So this is a more 
conservative estimate for locality preference.

##########
File path: core/src/main/scala/org/apache/spark/MapOutputTracker.scala
##########
@@ -367,6 +480,28 @@ private[spark] abstract class MapOutputTracker(conf: 
SparkConf) extends Logging
       startPartition: Int,
       endPartition: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])]
 
+  /**
+   * Called from executors upon fetch failure on an entire merged shuffle 
partition. Such failures
+   * can happen if the shuffle client fails to fetch the metadata for the 
given merged shuffle
+   * partition. This method is to get the server URIs and output sizes for 
each shuffle block that
+   * is merged in the specified merged shuffle block so fetch failure on a 
merged shuffle block can
+   * fall back to fetching the unmerged blocks.
+   */
+  def getMapSizesForMergeResult(
+      shuffleId: Int,
+      partitionId: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])]
+
+  /**
+   * Called from executors upon fetch failure on a merged shuffle partition 
chunk. This is to get
+   * the server URIs and output sizes for each shuffle block that is merged in 
the specified merged
+   * shuffle partition chunk so fetch failure on a merged shuffle block chunk 
can fall back to
+   * fetching the unmerged blocks.
+   */

Review comment:
       Add a note about `chunkBitmap`

##########
File path: core/src/main/scala/org/apache/spark/MapOutputTracker.scala
##########
@@ -812,17 +1038,64 @@ private[spark] class MapOutputTrackerWorker(conf: 
SparkConf) extends MapOutputTr
       startPartition: Int,
       endPartition: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, 
Int)])] = {
     logDebug(s"Fetching outputs for shuffle $shuffleId")
-    val statuses = getStatuses(shuffleId, conf)
+    val (mapOutputStatuses, mergedOutputStatuses) = getStatuses(shuffleId, 
conf)
     try {
-      val actualEndMapIndex = if (endMapIndex == Int.MaxValue) statuses.length 
else endMapIndex
+      val actualEndMapIndex =
+        if (endMapIndex == Int.MaxValue) mapOutputStatuses.length else 
endMapIndex
       logDebug(s"Convert map statuses for shuffle $shuffleId, " +
         s"mappers $startMapIndex-$actualEndMapIndex, partitions 
$startPartition-$endPartition")
       MapOutputTracker.convertMapStatuses(
-        shuffleId, startPartition, endPartition, statuses, startMapIndex, 
actualEndMapIndex)
+        shuffleId, startPartition, endPartition, mapOutputStatuses, 
startMapIndex,
+          actualEndMapIndex, Option(mergedOutputStatuses))
     } catch {
       case e: MetadataFetchFailedException =>
         // We experienced a fetch failure so our mapStatuses cache is 
outdated; clear it:
         mapStatuses.clear()
+        mergeStatuses.clear()
+        throw e
+    }
+  }
+
+  override def getMapSizesForMergeResult(
+      shuffleId: Int,
+      partitionId: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] 
= {
+    logDebug(s"Fetching backup outputs for shuffle $shuffleId, partition 
$partitionId")
+    // Fetch the map statuses and merge statuses again since they might have 
already been
+    // cleared by another task running in the same executor.
+    val (mapOutputStatuses, mergeResultStatuses) = getStatuses(shuffleId, conf)
+    try {
+      val mergeStatus = mergeResultStatuses(partitionId)
+      // If the original MergeStatus is no longer available, we cannot 
identify the list of
+      // unmerged blocks to fetch in this case. Throw 
MetadataFetchFailedException in this case.
+      MapOutputTracker.validateStatus(mergeStatus, shuffleId, partitionId)
+      // User the MergeStatus's partition level bitmap since we are doing 
partition level fallback

Review comment:
       nit: `User the` -> `Use`

##########
File path: core/src/main/scala/org/apache/spark/MapOutputTracker.scala
##########
@@ -987,18 +1277,51 @@ private[spark] object MapOutputTracker extends Logging {
       shuffleId: Int,
       startPartition: Int,
       endPartition: Int,
-      statuses: Array[MapStatus],
+      mapStatuses: Array[MapStatus],
       startMapIndex : Int,
-      endMapIndex: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] 
= {
-    assert (statuses != null)
+      endMapIndex: Int,
+      mergeStatuses: Option[Array[MergeStatus]] = None):
+      Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
+    assert (mapStatuses != null)
     val splitsByAddress = new HashMap[BlockManagerId, ListBuffer[(BlockId, 
Long, Int)]]
-    val iter = statuses.iterator.zipWithIndex
-    for ((status, mapIndex) <- iter.slice(startMapIndex, endMapIndex)) {
-      if (status == null) {
-        val errorMessage = s"Missing an output location for shuffle $shuffleId"
-        logError(errorMessage)
-        throw new MetadataFetchFailedException(shuffleId, startPartition, 
errorMessage)
-      } else {
+    // Only use MergeStatus for reduce tasks that fetch all map outputs. Since 
a merged shuffle
+    // partition consists of blocks merged in random order, we are unable to 
serve map index
+    // subrange requests. However, when a reduce task needs to fetch blocks 
from a subrange of
+    // map outputs, it usually indicates skewed partitions which push-based 
shuffle delegates
+    // to AQE to handle.

Review comment:
       Can we rephrase this as a `TODO` comment? This will be something we 
should support in the future.

##########
File path: core/src/main/scala/org/apache/spark/MapOutputTracker.scala
##########
@@ -987,18 +1277,51 @@ private[spark] object MapOutputTracker extends Logging {
       shuffleId: Int,
       startPartition: Int,
       endPartition: Int,
-      statuses: Array[MapStatus],
+      mapStatuses: Array[MapStatus],
       startMapIndex : Int,
-      endMapIndex: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] 
= {
-    assert (statuses != null)
+      endMapIndex: Int,
+      mergeStatuses: Option[Array[MergeStatus]] = None):
+      Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
+    assert (mapStatuses != null)
     val splitsByAddress = new HashMap[BlockManagerId, ListBuffer[(BlockId, 
Long, Int)]]
-    val iter = statuses.iterator.zipWithIndex
-    for ((status, mapIndex) <- iter.slice(startMapIndex, endMapIndex)) {
-      if (status == null) {
-        val errorMessage = s"Missing an output location for shuffle $shuffleId"
-        logError(errorMessage)
-        throw new MetadataFetchFailedException(shuffleId, startPartition, 
errorMessage)
-      } else {
+    // Only use MergeStatus for reduce tasks that fetch all map outputs. Since 
a merged shuffle
+    // partition consists of blocks merged in random order, we are unable to 
serve map index
+    // subrange requests. However, when a reduce task needs to fetch blocks 
from a subrange of
+    // map outputs, it usually indicates skewed partitions which push-based 
shuffle delegates
+    // to AQE to handle.
+    if (mergeStatuses.isDefined && startMapIndex == 0 && endMapIndex == 
mapStatuses.length) {
+      // We have MergeStatus and full range of mapIds are requested so return 
a merged block.
+      val numMaps = mapStatuses.length
+      mergeStatuses.get.zipWithIndex.slice(startPartition, 
endPartition).foreach {
+        case (mergeStatus, partId) =>
+          val remainingMapStatuses = if (mergeStatus != null && 
mergeStatus.totalSize > 0) {
+            // If MergeStatus is available for the given partition, add 
location of the
+            // pre-merged shuffle partition for this partition ID. Here we 
create a
+            // ShuffleBlockId with mapId being -1 to indicate this is a merged 
shuffle block.

Review comment:
       Add a constant (with value -1) and reference that instead of `-1` 
directly.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]



---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to