mridulm commented on a change in pull request #30480:
URL: https://github.com/apache/spark/pull/30480#discussion_r602044852



##########
File path: core/src/main/scala/org/apache/spark/MapOutputTracker.scala
##########
@@ -833,33 +1106,44 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr
    *
    * (It would be nice to remove this restriction in the future.)
    */
-  private def getStatuses(shuffleId: Int, conf: SparkConf): Array[MapStatus] = {
-    val statuses = mapStatuses.get(shuffleId).orNull
-    if (statuses == null) {
-      logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them")
+  private def getStatuses(
+      shuffleId: Int, conf: SparkConf): (Array[MapStatus], Array[MergeStatus]) = {
+    val mapOutputStatuses = mapStatuses.get(shuffleId).orNull
+    val mergeResultStatuses = mergeStatuses.get(shuffleId).orNull
+    if (mapOutputStatuses == null || (fetchMergeResult && mergeResultStatuses == null)) {
+      logInfo("Don't have map/merge outputs for shuffle " + shuffleId + ", fetching them")
       val startTimeNs = System.nanoTime()
       fetchingLock.withLock(shuffleId) {
-        var fetchedStatuses = mapStatuses.get(shuffleId).orNull
-        if (fetchedStatuses == null) {
-          logInfo("Doing the fetch; tracker endpoint = " + trackerEndpoint)
+        var fetchedMapStatuses = mapStatuses.get(shuffleId).orNull
+        if (fetchedMapStatuses == null) {
+          logInfo("Doing the map fetch; tracker endpoint = " + trackerEndpoint)
           val fetchedBytes = askTracker[Array[Byte]](GetMapOutputStatuses(shuffleId))
-          fetchedStatuses = MapOutputTracker.deserializeMapStatuses(fetchedBytes, conf)
-          logInfo("Got the output locations")
-          mapStatuses.put(shuffleId, fetchedStatuses)
+          fetchedMapStatuses = MapOutputTracker.deserializeOutputStatuses(fetchedBytes, conf)
+          logInfo("Got the map output locations")
+          mapStatuses.put(shuffleId, fetchedMapStatuses)
         }
-        logDebug(s"Fetching map output statuses for shuffle $shuffleId took " +
+        var fetchedMergeStatuses = mergeStatuses.get(shuffleId).orNull
+        if (fetchMergeResult && fetchedMergeStatuses == null) {
+          logInfo("Doing the merge fetch; tracker endpoint = " + trackerEndpoint)
+          val fetchedBytes = askTracker[Array[Byte]](GetMergeResultStatuses(shuffleId))
+          fetchedMergeStatuses = MapOutputTracker.deserializeOutputStatuses(fetchedBytes, conf)
+          logInfo("Got the merge output locations")
+          mergeStatuses.put(shuffleId, fetchedMergeStatuses)
+        }
+        logDebug(s"Fetching map/merge output statuses for shuffle $shuffleId took " +
          s"${TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTimeNs)} ms")
-        fetchedStatuses
+        (fetchedMapStatuses, fetchedMergeStatuses)
      }

Review comment:
       I would not expect multiple RPCs to be the preferred option, given the impact on the driver, but code simplicity needs to be balanced against that.
   CC @JoshRosen, @Ngone51, who last made changes here. Any thoughts on this?
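   For reference, a single-round-trip alternative could look roughly like the sketch below. This is only an illustration: the `GetMapAndMergeOutputStatuses` message and the `(Array[Byte], Array[Byte])` reply encoding are assumptions, not something this PR defines.

   ```scala
   // Hypothetical combined tracker message (not part of this PR): one RPC
   // returns both serialized status arrays, halving the round trips per shuffle.
   case class GetMapAndMergeOutputStatuses(shuffleId: Int)
     extends MapOutputTrackerMessage

   // Worker side: a single ask, then deserialize both payloads.
   val (mapBytes, mergeBytes) =
     askTracker[(Array[Byte], Array[Byte])](GetMapAndMergeOutputStatuses(shuffleId))
   val fetchedMapStatuses = MapOutputTracker.deserializeOutputStatuses(mapBytes, conf)
   val fetchedMergeStatuses = MapOutputTracker.deserializeOutputStatuses(mergeBytes, conf)
   ```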

##########
File path: core/src/main/scala/org/apache/spark/scheduler/MergeStatus.scala
##########
@@ -0,0 +1,113 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler
+
+import java.io.{Externalizable, ObjectInput, ObjectOutput}
+
+import org.roaringbitmap.RoaringBitmap
+
+import org.apache.spark.network.shuffle.protocol.MergeStatuses
+import org.apache.spark.storage.BlockManagerId
+import org.apache.spark.util.Utils
+
+/**
+ * The status for the result of merging shuffle partition blocks per individual shuffle partition,
+ * maintained by the scheduler. The scheduler separates the
+ * [[org.apache.spark.network.shuffle.protocol.MergeStatuses]] received from the
+ * ExternalShuffleService into individual [[MergeStatus]] entries, which are maintained inside
+ * MapOutputTracker to be served to the reducers when they start fetching shuffle partition
+ * blocks. Note that the reducers are ultimately fetching individual chunks inside a merged
+ * shuffle file, as explained in [[org.apache.spark.network.shuffle.RemoteBlockPushResolver]].
+ * Between the scheduler-maintained MergeStatus and the per-shuffle-partition meta file
+ * maintained by the shuffle service, the metadata for push-based shuffle is effectively divided
+ * into 2 layers: the scheduler tracks the top-level metadata at the shuffle partition level
+ * with MergeStatus, while the shuffle service maintains the partition-level metadata about how
+ * to further divide a merged shuffle partition into multiple chunks via the per-partition meta
+ * file. This helps reduce the amount of metadata the scheduler needs to maintain for push-based
+ * shuffle.
+ */
+private[spark] class MergeStatus(
+    private[this] var loc: BlockManagerId,
+    private[this] var mapTracker: RoaringBitmap,
+    private[this] var size: Long)
+  extends Externalizable with ShuffleOutputStatus {
+
+  protected def this() = this(null, null, -1) // For deserialization only
+
+  def location: BlockManagerId = loc
+
+  def totalSize: Long = size
+
+  def tracker: RoaringBitmap = mapTracker
+
+  /**
+   * Get the list of mapper IDs for missing mapper partition blocks that are not merged.
+   * The reducer will use this information to decide which shuffle partition blocks to
+   * fetch in the original way.
+   */
+  def getMissingMaps(numMaps: Int): Seq[Int] = {
+    (0 until numMaps).filter(i => !mapTracker.contains(i))
+  }
+
+  /**
+   * Get the number of missing map outputs for missing mapper partition blocks that are not merged.
+   */
+  def getNumMissingMapOutputs(numMaps: Int): Int = {
+    getMissingMaps(numMaps).length

Review comment:
       nit: Avoid creating a temporary Seq here and compute the count directly - `getNumMissingMapOutputs` is called a large number of times.
   
   ```suggestion
       (0 until numMaps).count(i => !mapTracker.contains(i))
   ```
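   As a self-contained illustration of the nit, using RoaringBitmap the same way the class above does (the bitmap contents here are made up):

   ```scala
   import org.roaringbitmap.RoaringBitmap

   val mapTracker = new RoaringBitmap()
   mapTracker.add(0L, 3L) // maps 0, 1, 2 merged; with numMaps = 5, maps 3 and 4 are missing

   // Original form: builds an intermediate Seq[Int] only to take its length.
   val viaSeq = (0 until 5).filter(i => !mapTracker.contains(i)).length

   // Suggested form: counts directly, with no intermediate collection.
   val viaCount = (0 until 5).count(i => !mapTracker.contains(i))

   assert(viaSeq == 2 && viaCount == 2)
   ```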
   
   
   
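   To make the two-layer metadata split described in the MergeStatus scaladoc concrete, here is a minimal sketch of how a reducer-side consumer might use a MergeStatus to split its fetches. `planFetches`, `fetchMergedChunks`, and `fetchOriginalBlocks` are illustrative stubs, not Spark APIs; the real dispatch logic lives in the shuffle reader.

   ```scala
   import org.apache.spark.storage.BlockManagerId

   // Illustrative stubs only; not Spark APIs.
   def fetchMergedChunks(loc: BlockManagerId, totalSize: Long): Unit =
     println(s"fetch chunks of the merged partition ($totalSize bytes) from $loc")
   def fetchOriginalBlocks(mapIds: Seq[Int]): Unit =
     println(s"fall back to original shuffle blocks from mappers $mapIds")

   // Sketch: fetch merged data where available, fall back for unmerged mappers.
   def planFetches(status: MergeStatus, numMaps: Int): Unit = {
     fetchMergedChunks(status.location, status.totalSize)
     val missing = status.getMissingMaps(numMaps) // mappers absent from the bitmap
     if (missing.nonEmpty) fetchOriginalBlocks(missing)
   }
   ```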



