otterc commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r656791069



##########
File path: core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala
##########
@@ -0,0 +1,325 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import java.util.concurrent.TimeUnit
+
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+import scala.util.{Failure, Success}
+
+import org.roaringbitmap.RoaringBitmap
+
+import org.apache.spark.MapOutputTracker
+import org.apache.spark.MapOutputTracker.SHUFFLE_PUSH_MAP_ID
+import org.apache.spark.internal.Logging
+import org.apache.spark.network.shuffle.{BlockStoreClient, MergedBlockMeta, 
MergedBlocksMetaListener}
+import org.apache.spark.storage.BlockManagerId.SHUFFLE_MERGER_IDENTIFIER
+import org.apache.spark.storage.ShuffleBlockFetcherIterator._
+
+/**
+ * Helper class for [[ShuffleBlockFetcherIterator]] that encapsulates all the push-based
+ * functionality to fetch push-merged block meta and shuffle chunks.
+ * A push-merged block contains multiple shuffle chunks, where each shuffle chunk contains
+ * multiple shuffle blocks that belong to the common reduce partition and were merged by the
+ * external shuffle service (ESS) to that chunk.
+ */
+private class PushBasedFetchHelper(
+   private val iterator: ShuffleBlockFetcherIterator,
+   private val shuffleClient: BlockStoreClient,
+   private val blockManager: BlockManager,
+   private val mapOutputTracker: MapOutputTracker) extends Logging {
+
+  private[this] val startTimeNs = System.nanoTime()
+
+  private[storage] val localShuffleMergerBlockMgrId = BlockManagerId(
+    SHUFFLE_MERGER_IDENTIFIER, blockManager.blockManagerId.host,
+    blockManager.blockManagerId.port, blockManager.blockManagerId.topologyInfo)
+
+  /**
+   * Maps each shuffle chunk id to its bitmap. Entries are written by `addChunk` and
+   * `createChunkBlockInfosFromMetaResponse` and dropped by `removeChunk`.
+   */
+  private[this] val chunksMetaMap = new mutable.HashMap[ShuffleBlockChunkId, 
RoaringBitmap]()
+
+  /**
+   * Returns true if the address is for a push-merged block, i.e. its executor id equals
+   * `SHUFFLE_MERGER_IDENTIFIER`.
+   *
+   * @param address block manager id to test.
+   */
+  def isPushMergedShuffleBlockAddress(address: BlockManagerId): Boolean = {
+    SHUFFLE_MERGER_IDENTIFIER == address.executorId
+  }
+
+  /**
+   * Returns true if the address is of a remote push-merged block, false otherwise.
+   * "Remote" means the address is a push-merged address whose host differs from the host of
+   * this block manager; only the host is compared, not the port.
+   *
+   * @param address block manager id to test.
+   */
+  def isRemotePushMergedBlockAddress(address: BlockManagerId): Boolean = {
+    isPushMergedShuffleBlockAddress(address) && address.host != 
blockManager.blockManagerId.host
+  }
+
+  /**
+   * Returns true if the address is of a local push-merged block, false otherwise.
+   * "Local" means the address is a push-merged address whose host equals the host of this
+   * block manager; only the host is compared, not the port.
+   *
+   * @param address block manager id to test.
+   */
+  def isLocalPushMergedBlockAddress(address: BlockManagerId): Boolean = {
+    isPushMergedShuffleBlockAddress(address) && address.host == 
blockManager.blockManagerId.host
+  }
+
+  /**
+   * Removes the bitmap stored for the given shuffle chunk, if there is one.
+   *
+   * This is executed by the task thread when the `iterator.next()` is invoked and the iterator
+   * processes a response of type [[ShuffleBlockFetcherIterator.SuccessFetchResult]].
+   *
+   * @param blockId shuffle chunk id.
+   */
+  def removeChunk(blockId: ShuffleBlockChunkId): Unit = {
+    chunksMetaMap.remove(blockId)
+  }
+
+  /**
+   * Stores the bitmap for the given shuffle chunk, replacing any bitmap already stored for it.
+   *
+   * This is executed by the task thread when the `iterator.next()` is invoked and the iterator
+   * processes a response of type
+   * [[ShuffleBlockFetcherIterator.PushMergedLocalMetaFetchResult]].
+   *
+   * @param blockId   shuffle chunk id.
+   * @param chunkMeta bitmap to associate with the chunk.
+   */
+  def addChunk(blockId: ShuffleBlockChunkId, chunkMeta: RoaringBitmap): Unit = 
{
+    chunksMetaMap(blockId) = chunkMeta
+  }
+
+  /**
+   * Registers the bitmap of every chunk of a push-merged block and returns one fetch entry per
+   * chunk. Each chunk is given the same approximate size, `blockSize / numChunks` (integer
+   * division).
+   *
+   * This is executed by the task thread when the `iterator.next()` is invoked and the iterator
+   * processes a response of type
+   * [[ShuffleBlockFetcherIterator.PushMergedRemoteMetaFetchResult]].
+   *
+   * NOTE(review): `blockSize / numChunks` throws an ArithmeticException when `numChunks` is 0,
+   * and `bitmaps(i)` assumes `bitmaps` holds at least `numChunks` entries - confirm that the
+   * callers guarantee both.
+   *
+   * @param shuffleId shuffle id.
+   * @param reduceId  reduce id.
+   * @param blockSize size of the push-merged block.
+   * @param numChunks number of chunks in the push-merged block.
+   * @param bitmaps   chunk bitmaps, where each bitmap contains all the mapIds that were merged
+   *                  to that chunk.
+   * @return  shuffle chunks to fetch: one (chunk id, approximate chunk size,
+   *          `SHUFFLE_PUSH_MAP_ID`) tuple per chunk, in chunk-index order.
+   */
+  def createChunkBlockInfosFromMetaResponse(
+      shuffleId: Int,
+      reduceId: Int,
+      blockSize: Long,
+      numChunks: Int,
+      bitmaps: Array[RoaringBitmap]): ArrayBuffer[(BlockId, Long, Int)] = {
+    val approxChunkSize = blockSize / numChunks
+    val blocksToFetch = new ArrayBuffer[(BlockId, Long, Int)]()
+    for (i <- 0 until numChunks) {
+      val blockChunkId = ShuffleBlockChunkId(shuffleId, reduceId, i)
+      chunksMetaMap.put(blockChunkId, bitmaps(i))
+      logDebug(s"adding block chunk $blockChunkId of size $approxChunkSize")
+      blocksToFetch += ((blockChunkId, approxChunkSize, SHUFFLE_PUSH_MAP_ID))
+    }
+    blocksToFetch
+  }
+
+  /**
+   * This is executed by the task thread when the iterator is initialized and only if it has
+   * push-merged blocks for which it needs to fetch the metadata.
+   *
+   * The sizes of the requested blocks are first indexed by (shuffleId, reduceId). Then one
+   * `getMergedBlockMeta` call is issued per block in the request, all sharing one listener.
+   * The listener adds a `PushMergedRemoteMetaFetchResult` to the iterator's results queue on
+   * success. It adds a `PushMergedRemoteMetaFailedFetchResult` when the request fails, or when
+   * building the success result throws an exception (for example while reading the chunk
+   * bitmaps).
+   *
+   * NOTE(review): every block id in the request is cast to `ShuffleBlockId`; a request that
+   * holds any other block id type fails with a ClassCastException - confirm callers only pass
+   * push-merged block requests.
+   *
+   * @param req [[ShuffleBlockFetcherIterator.FetchRequest]] that only contains requests to
+   *            fetch metadata of push-merged blocks.
+   */
+  def sendFetchMergedStatusRequest(req: FetchRequest): Unit = {
+    val sizeMap = req.blocks.map {
+      case FetchBlockInfo(blockId, size, _) =>
+        val shuffleBlockId = blockId.asInstanceOf[ShuffleBlockId]
+        ((shuffleBlockId.shuffleId, shuffleBlockId.reduceId), size)
+    }.toMap
+    val address = req.address
+    val mergedBlocksMetaListener = new MergedBlocksMetaListener {
+      override def onSuccess(shuffleId: Int, reduceId: Int, meta: 
MergedBlockMeta): Unit = {
+        logInfo(s"Received the meta of push-merged block for ($shuffleId, 
$reduceId)  " +
+          s"from ${req.address.host}:${req.address.port}")
+        try {
+          
iterator.addToResultsQueue(PushMergedRemoteMetaFetchResult(shuffleId, reduceId,
+            sizeMap((shuffleId, reduceId)), meta.getNumChunks, 
meta.readChunkBitmaps(), address))
+        } catch {
+          case exception: Exception =>
+            logError(s"Failed to parse the meta of push-merged block for 
($shuffleId, " +
+              s"$reduceId) from ${req.address.host}:${req.address.port}", 
exception)
+            iterator.addToResultsQueue(
+              PushMergedRemoteMetaFailedFetchResult(shuffleId, reduceId, 
address))
+        }
+      }
+
+      override def onFailure(shuffleId: Int, reduceId: Int, exception: 
Throwable): Unit = {
+        logError(s"Failed to get the meta of push-merged block for 
($shuffleId, $reduceId) " +
+          s"from ${req.address.host}:${req.address.port}", exception)
+        iterator.addToResultsQueue(
+          PushMergedRemoteMetaFailedFetchResult(shuffleId, reduceId, address))
+      }
+    }
+    req.blocks.foreach { block =>
+      val shuffleBlockId = block.blockId.asInstanceOf[ShuffleBlockId]
+      shuffleClient.getMergedBlockMeta(address.host, address.port, 
shuffleBlockId.shuffleId,
+        shuffleBlockId.reduceId, mergedBlocksMetaListener)
+    }
+  }
+
+  /**
+   * This is executed by the task thread when the iterator is initialized. It fetches all the
+   * outstanding push-merged local blocks.
+   *
+   * Nothing is done when the set is empty. Otherwise the fetch is delegated to
+   * `fetchPushMergedLocalBlocks` through `blockManager.hostLocalDirManager.foreach`.
+   * NOTE(review): presumably `hostLocalDirManager` is an Option, in which case no block is
+   * fetched and no result is reported when it is empty - confirm.
+   *
+   * @param pushMergedLocalBlocks set of identified push-merged local block ids.
+   */
+  def fetchAllPushMergedLocalBlocks(
+      pushMergedLocalBlocks: mutable.LinkedHashSet[BlockId]): Unit = {
+    if (pushMergedLocalBlocks.nonEmpty) {
+      blockManager.hostLocalDirManager.foreach(fetchPushMergedLocalBlocks(_, 
pushMergedLocalBlocks))
+    }
+  }
+
+  /**
+   * Fetch the push-merged blocks dirs if they are not in the cache and 
eventually fetch push-merged
+   * local blocks.
+   */
+  private def fetchPushMergedLocalBlocks(
+      hostLocalDirManager: HostLocalDirManager,
+      pushMergedLocalBlocks: mutable.LinkedHashSet[BlockId]): Unit = {
+    val cachedMergerDirs = hostLocalDirManager.getCachedHostLocalDirs.get(
+      SHUFFLE_MERGER_IDENTIFIER)
+    if (cachedMergerDirs.isDefined) {
+      logDebug(s"Fetching local push-merged blocks with cached executors dir: 
" +
+        s"${cachedMergerDirs.get.mkString(", ")}")
+      pushMergedLocalBlocks.foreach { blockId =>
+        fetchPushMergedLocalBlock(blockId, cachedMergerDirs.get,
+          localShuffleMergerBlockMgrId)
+      }
+    } else {
+      logDebug(s"Asynchronous fetching local push-merged blocks without cached 
executors dir")
+      hostLocalDirManager.getHostLocalDirs(localShuffleMergerBlockMgrId.host,
+        localShuffleMergerBlockMgrId.port, Array(SHUFFLE_MERGER_IDENTIFIER)) {
+        case Success(dirs) =>
+          pushMergedLocalBlocks.takeWhile {

Review comment:
       I have added a unit test (UT) to catch this.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]



---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to