venkata91 commented on a change in pull request #30480:
URL: https://github.com/apache/spark/pull/30480#discussion_r611903588



##########
File path: core/src/main/scala/org/apache/spark/MapOutputTracker.scala
##########
@@ -833,33 +1106,44 @@ private[spark] class MapOutputTrackerWorker(conf: 
SparkConf) extends MapOutputTr
    *
    * (It would be nice to remove this restriction in the future.)
    */
-  private def getStatuses(shuffleId: Int, conf: SparkConf): Array[MapStatus] = 
{
-    val statuses = mapStatuses.get(shuffleId).orNull
-    if (statuses == null) {
-      logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching 
them")
+  private def getStatuses(
+      shuffleId: Int, conf: SparkConf): (Array[MapStatus], Array[MergeStatus]) 
= {
+    val mapOutputStatuses = mapStatuses.get(shuffleId).orNull
+    val mergeResultStatuses = mergeStatuses.get(shuffleId).orNull
+    if (mapOutputStatuses == null || (fetchMergeResult && mergeResultStatuses 
== null)) {
+      logInfo("Don't have map/merge outputs for shuffle " + shuffleId + ", 
fetching them")
       val startTimeNs = System.nanoTime()
       fetchingLock.withLock(shuffleId) {
-        var fetchedStatuses = mapStatuses.get(shuffleId).orNull
-        if (fetchedStatuses == null) {
-          logInfo("Doing the fetch; tracker endpoint = " + trackerEndpoint)
+        var fetchedMapStatuses = mapStatuses.get(shuffleId).orNull
+        if (fetchedMapStatuses == null) {
+          logInfo("Doing the map fetch; tracker endpoint = " + trackerEndpoint)
           val fetchedBytes = 
askTracker[Array[Byte]](GetMapOutputStatuses(shuffleId))
-          fetchedStatuses = 
MapOutputTracker.deserializeMapStatuses(fetchedBytes, conf)
-          logInfo("Got the output locations")
-          mapStatuses.put(shuffleId, fetchedStatuses)
+          fetchedMapStatuses = 
MapOutputTracker.deserializeOutputStatuses(fetchedBytes, conf)
+          logInfo("Got the map output locations")
+          mapStatuses.put(shuffleId, fetchedMapStatuses)
         }
-        logDebug(s"Fetching map output statuses for shuffle $shuffleId took " +
+        var fetchedMergeStatues = mergeStatuses.get(shuffleId).orNull
+        if (fetchMergeResult && fetchedMergeStatues == null) {
+          logInfo("Doing the merge fetch; tracker endpoint = " + 
trackerEndpoint)
+          val fetchedBytes = 
askTracker[Array[Byte]](GetMergeResultStatuses(shuffleId))
+          fetchedMergeStatues = 
MapOutputTracker.deserializeOutputStatuses(fetchedBytes, conf)
+          logInfo("Got the merge output locations")
+          mergeStatuses.put(shuffleId, fetchedMergeStatues)
+        }
+        logDebug(s"Fetching map/merge output statuses for shuffle $shuffleId 
took " +
           s"${TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTimeNs)} 
ms")
-        fetchedStatuses
+        (fetchedMapStatuses, fetchedMergeStatues)
       }

Review comment:
       Makes sense, @mridulm. Instead of `Array[Array[Byte]]`, I used a tuple of 
`(Array[Byte], Array[Byte])`. cc @Victsm 




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]



---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to