marin-ma commented on code in PR #12370:
URL: https://github.com/apache/gluten/pull/12370#discussion_r3506668067


##########
shims/spark33/src/main/scala/org/apache/spark/storage/GlutenShuffleBlockFetcherIterator.scala:
##########
@@ -0,0 +1,1506 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.storage
+
+import org.apache.spark.{MapOutputTracker, TaskContext}
+import org.apache.spark.MapOutputTracker.SHUFFLE_PUSH_MAP_ID
+import org.apache.spark.errors.SparkCoreErrors
+import org.apache.spark.internal.Logging
+import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, 
ManagedBuffer}
+import org.apache.spark.network.shuffle._
+import org.apache.spark.network.shuffle.checksum.{Cause, ShuffleChecksumHelper}
+import org.apache.spark.network.util.TransportConf
+import org.apache.spark.shuffle.ShuffleReadMetricsReporter
+import org.apache.spark.util.{TaskCompletionListener, Utils}
+
+import io.netty.util.internal.OutOfDirectMemoryError
+import org.apache.commons.io.IOUtils
+
+import javax.annotation.concurrent.GuardedBy
+
+import java.io.{InputStream, IOException}
+import java.nio.channels.ClosedByInterruptException
+import java.util.concurrent.{ConcurrentHashMap, LinkedBlockingQueue, TimeUnit}
+import java.util.zip.CheckedInputStream
+
+import scala.collection.mutable
+import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Queue}
+import scala.util.{Failure, Success}
+
+/**
+ * An iterator that fetches multiple blocks. For local blocks, it fetches from the local block
+ * manager. For remote blocks, it fetches them using the provided [[BlockStoreClient]].
+ *
+ * This creates an iterator of (BlockId, InputStream) tuples so the caller can handle blocks in a
+ * pipelined fashion as they are received.
+ *
+ * The implementation throttles the remote fetches so they don't exceed maxBytesInFlight, to avoid
+ * using too much memory.
+ *
+ * @param context
+ *   [[TaskContext]], used for metrics update
+ * @param shuffleClient
+ *   [[BlockStoreClient]] for fetching remote blocks
+ * @param blockManager
+ *   [[BlockManager]] for reading local blocks
+ * @param blocksByAddress
+ *   blocks to fetch, grouped by the [[BlockManagerId]]. For each block we also require two pieces
+ *   of info: 1. the size (in bytes, as a long field) in order to throttle the memory usage; 2. the
+ *   mapIndex for this block, which indicates the index in the map stage. Note that zero-sized
+ *   blocks are already excluded, which happened in
+ *   [[org.apache.spark.MapOutputTracker.convertMapStatuses]].
+ * @param mapOutputTracker
+ *   [[MapOutputTracker]] for falling back to fetching the original blocks if we fail to fetch
+ *   shuffle chunks when push based shuffle is enabled.
+ * @param streamWrapper
+ *   A function to wrap the returned input stream.
+ * @param maxBytesInFlight
+ *   max size (in bytes) of remote blocks to fetch at any given point.
+ * @param maxReqsInFlight
+ *   max number of remote requests to fetch blocks at any given point.
+ * @param maxBlocksInFlightPerAddress
+ *   max number of shuffle blocks being fetched at any given point for a given remote host:port.
+ * @param maxReqSizeShuffleToMem
+ *   max size (in bytes) of a request that can be fetched into memory; larger requests are fetched
+ *   to disk.
+ * @param maxAttemptsOnNettyOOM
+ *   the max number of times a block can be retried due to Netty OOM before its fetch failure is
+ *   reported.
+ * @param detectCorrupt
+ *   whether to detect any corruption in fetched blocks.
+ * @param checksumEnabled
+ *   whether the shuffle checksum is enabled. When enabled, Spark will try to diagnose the cause of
+ *   the block corruption.
+ * @param checksumAlgorithm
+ *   the checksum algorithm that is used when calculating the checksum value for the block data.
+ * @param shuffleMetrics
+ *   used to report shuffle metrics.
+ * @param doBatchFetch
+ *   fetch contiguous shuffle blocks from the same executor in batch if the server side supports
+ *   it.
+ */
+final class GlutenShuffleBlockFetcherIterator(
+    context: TaskContext,
+    shuffleClient: BlockStoreClient,
+    blockManager: BlockManager,
+    mapOutputTracker: MapOutputTracker,
+    blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])],
+    streamWrapper: (BlockId, InputStream) => InputStream,
+    maxBytesInFlight: Long,
+    maxReqsInFlight: Int,
+    maxBlocksInFlightPerAddress: Int,
+    val maxReqSizeShuffleToMem: Long,
+    maxAttemptsOnNettyOOM: Int,
+    detectCorrupt: Boolean,
+    detectCorruptUseExtraMemory: Boolean,
+    checksumEnabled: Boolean,
+    checksumAlgorithm: String,
+    shuffleMetrics: ShuffleReadMetricsReporter,
+    doBatchFetch: Boolean)
+  extends GlutenShuffleBlockFetcherIteratorBase
+  with DownloadFileManager
+  with Logging {
+
+  import ShuffleBlockFetcherIterator._
+
+  // Make remote requests at most maxBytesInFlight / 5 in length; the reason to keep them
+  // smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5
+  // nodes, rather than blocking on reading output from one node.
+  // The max(..., 1L) keeps the target at least one byte when maxBytesInFlight < 5.
+  private val targetRemoteRequestSize = math.max(maxBytesInFlight / 5, 1L)
+
+  /**
+   * Total number of blocks to fetch. Increased while blocks are partitioned and grouped into
+   * requests (partitionBlocksByFetchMode / createFetchRequests).
+   */
+  private[this] var numBlocksToFetch = 0
+
+  /**
+   * The number of blocks processed by the caller. The iterator is exhausted when
+   * [[numBlocksProcessed]] == [[numBlocksToFetch]].
+   */
+  private[this] var numBlocksProcessed = 0
+
+  // Construction time; used when trace-logging how long remote blocks took to arrive.
+  private[this] val startTimeNs = System.nanoTime()
+
+  /**
+   * Host local blocks to fetch, excluding zero-sized blocks, as (blockId, mapIndex) pairs.
+   * cleanup() uses it to attribute unconsumed buffers to the local-read metrics.
+   */
+  private[this] val hostLocalBlocks = scala.collection.mutable.LinkedHashSet[(BlockId, Int)]()
+
+  /**
+   * A queue to hold our results. This turns the asynchronous, callback-based fetches made through
+   * the [[BlockStoreClient]] into a synchronous model (iterator).
+   */
+  private[this] val results = new LinkedBlockingQueue[FetchResult]
+
+  /**
+   * Current [[FetchResult]] being processed per thread, keyed by the consuming thread's id. We
+   * track this so we can release the current buffer in case of a runtime exception when processing
+   * the current buffer. A ConcurrentHashMap is used to support concurrent access from multiple
+   * threads.
+   *
+   * NOTE(review): releaseCurrentResultBuffer (which cleanup() also calls) only removes the
+   * calling thread's entry, so a cleanup on one thread does not release buffers held by other
+   * threads — confirm that each consumer thread always releases its own entry.
+   */
+  private[this] val currentResults: ConcurrentHashMap[Long, SuccessFetchResult] =
+    new ConcurrentHashMap[Long, SuccessFetchResult]()
+
+  /**
+   * Queue of fetch requests to issue; we'll pull requests off this gradually to make sure that the
+   * number of bytes in flight is limited to maxBytesInFlight.
+   */
+  private[this] val fetchRequests = new Queue[FetchRequest]
+
+  /**
+   * Queue of fetch requests which could not be issued the first time they were dequeued. These
+   * requests are tried again when the fetch constraints are satisfied.
+   */
+  private[this] val deferredFetchRequests = new HashMap[BlockManagerId, Queue[FetchRequest]]()
+
+  /** Current bytes in flight from our requests */
+  private[this] var bytesInFlight = 0L
+
+  /** Current number of requests in flight */
+  private[this] var reqsInFlight = 0
+
+  /** Current number of blocks in flight per host:port */
+  private[this] val numBlocksInFlightPerAddress = new HashMap[BlockManagerId, Int]()
+
+  /**
+   * Number of times each block (keyed by its string id) has been retried because of Netty OOM.
+   * A block is no longer retried once its count reaches [[maxAttemptsOnNettyOOM]]; the entry is
+   * removed when the block is fetched successfully.
+   */
+  private[this] val blockOOMRetryCounts = new HashMap[String, Int]
+
+  /**
+   * Blocks that could not be decompressed successfully; used to guarantee that we retry at most
+   * once for those corrupted blocks.
+   */
+  private[this] val corruptedBlocks = mutable.HashSet[BlockId]()
+
+  /**
+   * Whether the iterator is still active. If isZombie is true, the success callback will no
+   * longer place fetched blocks into [[results]] (failure results may still be enqueued).
+   */
+  @GuardedBy("this")
+  private[this] var isZombie = false
+
+  /**
+   * Temp files used when remote requests too large for memory are fetched to disk. Files in this
+   * set are deleted by cleanup(). This is a layer of defensiveness against disk file leaks.
+   */
+  @GuardedBy("this")
+  private[this] val shuffleFilesSet = mutable.HashSet[DownloadFile]()
+
+  // Completion listener bound to this iterator; presumably registered with the task by
+  // initialize() so that cleanup() runs when the task ends — confirm.
+  private[this] val onCompleteCallback = new GlutenShuffleFetchCompletionListener(this)
+
+  // Push-based shuffle helper; used to recognise push-merged block addresses when partitioning.
+  private[this] val pushBasedFetchHelper = new GlutenPushBasedFetchHelper(
+    this,
+    shuffleClient,
+    blockManager,
+    mapOutputTracker)
+
+  // Runs during construction: it must stay after every field initializer it relies on, since
+  // fields declared below this call would still be uninitialized when it executes.
+  initialize()
+
+  /**
+   * Releases (decrements the reference count of) the buffer the calling thread is currently
+   * processing, if there is one. The entry is removed from `currentResults` before releasing, so
+   * a later call on the same thread, e.g. from cleanup(), cannot release the same buffer again.
+   */
+  private[storage] def releaseCurrentResultBuffer(): Unit = {
+    // ConcurrentHashMap.remove yields null when this thread holds no in-progress result.
+    Option(currentResults.remove(Thread.currentThread().getId)).foreach(_.buf.release())
+  }
+
+  /**
+   * Creates the temp file an oversized remote request is written to. sendRequest passes this
+   * iterator as the download-file manager when a request exceeds maxReqSizeShuffleToMem.
+   */
+  override def createTempFile(transportConf: TransportConf): DownloadFile = {
+    // No encryption or decryption is applied here, whatever the configs say: that is handled at
+    // another layer. When encryption is enabled, shuffle data is encrypted when first written to
+    // disk and stays encrypted while sent over the network.
+    val (_, tempLocalFile) = blockManager.diskBlockManager.createTempLocalBlock()
+    new SimpleDownloadFile(tempLocalFile, transportConf)
+  }
+
+  /**
+   * Records `file` so that cleanup() deletes it. Once the iterator is a zombie the file is not
+   * recorded and false is returned (presumably the caller then deletes it itself — confirm
+   * against the DownloadFileManager contract).
+   */
+  override def registerTempFileToClean(file: DownloadFile): Boolean = synchronized {
+    val accepted = !isZombie
+    if (accepted) {
+      shuffleFilesSet += file
+    }
+    accepted
+  }
+
+  /**
+   * Marks the iterator as a zombie and releases what it still holds:
+   *  - the calling thread's in-progress buffer;
+   *  - the buffer of every SuccessFetchResult still queued in `results`;
+   *  - every temp file registered through registerTempFileToClean.
+   * Queued buffers that came from other executors are recorded in the shuffle read metrics before
+   * they are released.
+   *
+   * NOTE(review): only the calling thread's entry in `currentResults` is released here; buffers
+   * held by other consumer threads are left to those threads — confirm that is intended.
+   */
+  private[storage] def cleanup(): Unit = {
+    // Set the flag under the lock: after this, onBlockFetchSuccess ignores newly fetched buffers
+    // and registerTempFileToClean rejects new files.
+    synchronized {
+      isZombie = true
+    }
+    releaseCurrentResultBuffer()
+
+    // Release buffers still sitting in the results queue (iterated, not drained).
+    val pending = results.iterator()
+    while (pending.hasNext) {
+      pending.next() match {
+        case SuccessFetchResult(blockId, mapIndex, address, _, buf, _) =>
+          val fromOtherExecutor = address != blockManager.blockManagerId
+          if (fromOtherExecutor && hostLocalBlocks.contains(blockId -> mapIndex)) {
+            shuffleMetrics.incLocalBlocksFetched(1)
+            shuffleMetrics.incLocalBytesRead(buf.size)
+          } else if (fromOtherExecutor) {
+            shuffleMetrics.incRemoteBytesRead(buf.size)
+            buf match {
+              case _: FileSegmentManagedBuffer => shuffleMetrics.incRemoteBytesReadToDisk(buf.size)
+              case _ =>
+            }
+            shuffleMetrics.incRemoteBlocksFetched(1)
+          }
+          buf.release()
+        case _ =>
+      }
+    }
+
+    // Best-effort deletion of the temp files; a failed delete is only logged.
+    shuffleFilesSet.foreach { tempFile =>
+      if (!tempFile.delete()) {
+        logWarning("Failed to cleanup shuffle fetch temp file " + tempFile.path())
+      }
+    }
+  }
+
+  /**
+   * Issues one remote fetch request and wires its callbacks into `results`.
+   *
+   * The in-flight accounting (`bytesInFlight`, `reqsInFlight`) is updated before fetching.
+   *  - Success: each block's buffer is retained and enqueued as a SuccessFetchResult.
+   *  - Netty OOM: the block is deferred, up to maxAttemptsOnNettyOOM times per block. Once no
+   *    blocks of the request remain outstanding, the deferred blocks are enqueued together as one
+   *    DeferFetchRequestResult.
+   *  - Any other failure: enqueues a FallbackOnPushMergedFailureResult for a shuffle chunk, or a
+   *    FailureFetchResult otherwise.
+   * Requests larger than maxReqSizeShuffleToMem are fetched to disk by passing this iterator as
+   * the download-file manager.
+   */
+  private[this] def sendRequest(req: FetchRequest): Unit = {
+    logDebug("Sending request for %d blocks (%s) from %s".format(
+      req.blocks.size,
+      Utils.bytesToString(req.size),
+      req.address.hostPort))
+    bytesInFlight += req.size
+    reqsInFlight += 1
+
+    // so we can look up the block info of each blockID: block id string -> (size, mapIndex)
+    val infoMap = req.blocks.map {
+      case FetchBlockInfo(blockId, size, mapIndex) => (blockId.toString, (size, mapIndex))
+    }.toMap
+    // Blocks of this request whose outcome is not yet recorded. Removed on success, on OOM
+    // deferral and on shuffle-chunk failure.
+    val remainingBlocks = new HashSet[String]() ++= infoMap.keys
+    // Blocks that hit Netty OOM; re-requested together once remainingBlocks is empty.
+    val deferredBlocks = new ArrayBuffer[String]()
+    val blockIds = req.blocks.map(_.blockId.toString)
+    val address = req.address
+
+    // Only invoked from inside the synchronized callbacks below.
+    @inline def enqueueDeferredFetchRequestIfNecessary(): Unit = {
+      if (remainingBlocks.isEmpty && deferredBlocks.nonEmpty) {
+        val blocks = deferredBlocks.map {
+          blockId =>
+            val (size, mapIndex) = infoMap(blockId)
+            FetchBlockInfo(BlockId(blockId), size, mapIndex)
+        }
+        results.put(DeferFetchRequestResult(FetchRequest(address, blocks.toSeq)))
+        deferredBlocks.clear()
+      }
+    }
+
+    val blockFetchingListener = new BlockFetchingListener {
+      override def onBlockFetchSuccess(blockId: String, buf: ManagedBuffer): Unit = {
+        // Only add the buffer to results queue if the iterator is not zombie,
+        // i.e. cleanup() has not been called yet.
+        GlutenShuffleBlockFetcherIterator.this.synchronized {
+          if (!isZombie) {
+            // Increment the ref count because we need to pass this to a different thread.
+            // This needs to be released after use.
+            buf.retain()
+            remainingBlocks -= blockId
+            blockOOMRetryCounts.remove(blockId)
+            // The last argument is true when this was the request's final outstanding block.
+            results.put(new SuccessFetchResult(
+              BlockId(blockId),
+              infoMap(blockId)._2,
+              address,
+              infoMap(blockId)._1,
+              buf,
+              remainingBlocks.isEmpty))
+            logDebug("remainingBlocks: " + remainingBlocks)
+            enqueueDeferredFetchRequestIfNecessary()
+          }
+        }
+        logTrace(s"Got remote block $blockId after ${Utils.getUsedTimeNs(startTimeNs)}")
+      }
+
+      override def onBlockFetchFailure(blockId: String, e: Throwable): Unit = {
+        GlutenShuffleBlockFetcherIterator.this.synchronized {
+          logError(s"Failed to get block(s) from ${req.address.host}:${req.address.port}", e)
+          e match {
+            // SPARK-27991: Catch the Netty OOM and set the flag `isNettyOOMOnShuffle` (shared among
+            // tasks) to true as early as possible. The pending fetch requests won't be sent
+            // afterwards until the flag is set to false on:
+            // 1) the Netty free memory >= maxReqSizeShuffleToMem
+            //    - we'll check this whenever there's a fetch request succeeds.
+            // 2) the number of in-flight requests becomes 0
+            //    - we'll check this in `fetchUpToMaxBytes` whenever it's invoked.
+            // Although Netty memory is shared across multiple modules, e.g., shuffle, rpc, the flag
+            // only takes effect for the shuffle due to the implementation simplicity concern.
+            // And we'll buffer the consecutive block failures caused by the OOM error until there's
+            // no remaining blocks in the current request. Then, we'll package these blocks into
+            // a same fetch request for the retry later. In this way, instead of creating the fetch
+            // request per block, it would help reduce the concurrent connections and data loads
+            // pressure at remote server.
+            // Note that catching OOM and do something based on it is only a workaround for
+            // handling the Netty OOM issue, which is not the best way towards memory management.
+            // We can get rid of it when we find a way to manage Netty's memory precisely.
+            // The guard also inserts a zero retry count for blocks seen for the first time.
+            case _: OutOfDirectMemoryError
+                if blockOOMRetryCounts.getOrElseUpdate(blockId, 0) < maxAttemptsOnNettyOOM =>
+              if (!isZombie) {
+                // Retry count before this attempt.
+                val failureTimes = blockOOMRetryCounts(blockId)
+                blockOOMRetryCounts(blockId) += 1
+                if (isNettyOOMOnShuffle.compareAndSet(false, true)) {
+                  // The fetcher can fail remaining blocks in batch for the same error. So we only
+                  // log once, when the shared flag flips, to avoid flooding the logs.
+                  logInfo(s"Block $blockId has failed $failureTimes times " +
+                    s"due to Netty OOM, will retry")
+                }
+                remainingBlocks -= blockId
+                deferredBlocks += blockId
+                enqueueDeferredFetchRequestIfNecessary()
+              }
+
+            case _ =>
+              val block = BlockId(blockId)
+              if (block.isShuffleChunk) {
+                remainingBlocks -= blockId
+                results.put(FallbackOnPushMergedFailureResult(
+                  block,
+                  address,
+                  infoMap(blockId)._1,
+                  remainingBlocks.isEmpty))
+              } else {
+                // NOTE(review): the block is not removed from remainingBlocks, so deferred blocks
+                // of this request are not re-enqueued afterwards; presumably acceptable because a
+                // FailureFetchResult ends the iteration — confirm.
+                results.put(FailureFetchResult(block, infoMap(blockId)._2, address, e))
+              }
+          }
+        }
+      }
+    }
+
+    // Fetch remote shuffle blocks to disk when the request is too large. Since the shuffle data is
+    // already encrypted and compressed over the wire(w.r.t. the related configs), we can just fetch
+    // the data and write it to file directly.
+    if (req.size > maxReqSizeShuffleToMem) {
+      shuffleClient.fetchBlocks(
+        address.host,
+        address.port,
+        address.executorId,
+        blockIds.toArray,
+        blockFetchingListener,
+        this)
+    } else {
+      shuffleClient.fetchBlocks(
+        address.host,
+        address.port,
+        address.executorId,
+        blockIds.toArray,
+        blockFetchingListener,
+        null)
+    }
+  }
+
+  /**
+   * Splits `blocksByAddress` by fetch mode and returns the remote [[FetchRequest]]s.
+   *
+   * Where each block goes:
+   *  - local blocks (this executor or the fallback storage) go to `localBlocks`;
+   *  - host-local blocks go to `hostLocalBlocksByExecutor`. These are other executors on this
+   *    host, and only when a host-local dir manager is available;
+   *  - push-merged blocks on this host go to `pushMergedLocalBlocks`;
+   *  - everything else, including push-merged blocks on other hosts, is grouped into remote fetch
+   *    requests by collectFetchRequests.
+   * Local and host-local block lists are passed through mergeContinuousShuffleBlockIdsIfNeeded
+   * with `doBatchFetch`. `numBlocksToFetch` grows by the number of blocks found in this call, and
+   * the host-local blocks are also recorded in `hostLocalBlocks`.
+   *
+   * This is called from initialize and also from the push-merged fallback, presumably triggered
+   * via [[GlutenPushBasedFetchHelper]] (the helper this iterator constructs) — confirm.
+   */
+  private[this] def partitionBlocksByFetchMode(
+      blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])],
+      localBlocks: mutable.LinkedHashSet[(BlockId, Int)],
+      hostLocalBlocksByExecutor: mutable.LinkedHashMap[BlockManagerId, Seq[(BlockId, Long, Int)]],
+      pushMergedLocalBlocks: mutable.LinkedHashSet[BlockId]): ArrayBuffer[FetchRequest] = {
+    logDebug(s"maxBytesInFlight: $maxBytesInFlight, targetRemoteRequestSize: "
+      + s"$targetRemoteRequestSize, maxBlocksInFlightPerAddress: $maxBlocksInFlightPerAddress")
+
+    // Partition to local, host-local, push-merged-local, remote (includes push-merged-remote)
+    // blocks. Remote blocks are further split into FetchRequests that are closed once they reach
+    // targetRemoteRequestSize bytes or maxBlocksInFlightPerAddress blocks, in order to limit the
+    // amount of data in flight.
+    val collectedRemoteRequests = new ArrayBuffer[FetchRequest]
+    var localBlockBytes = 0L
+    var hostLocalBlockBytes = 0L
+    var numHostLocalBlocks = 0
+    var pushMergedLocalBlockBytes = 0L
+    // Baseline for counting the blocks added by this call only.
+    val prevNumBlocksToFetch = numBlocksToFetch
+
+    // Blocks stored by this executor or by the fallback storage are read via the local path.
+    val fallback = FallbackStorage.FALLBACK_BLOCK_MANAGER_ID.executorId
+    val localExecIds = Set(blockManager.blockManagerId.executorId, fallback)
+    for ((address, blockInfos) <- blocksByAddress) {
+      checkBlockSizes(blockInfos)
+      if (pushBasedFetchHelper.isPushMergedShuffleBlockAddress(address)) {
+        // These are push-merged blocks or shuffle chunks of these blocks.
+        if (address.host == blockManager.blockManagerId.host) {
+          numBlocksToFetch += blockInfos.size
+          pushMergedLocalBlocks ++= blockInfos.map(_._1)
+          pushMergedLocalBlockBytes += blockInfos.map(_._2).sum
+        } else {
+          collectFetchRequests(address, blockInfos, collectedRemoteRequests)
+        }
+      } else if (localExecIds.contains(address.executorId)) {
+        val mergedBlockInfos = mergeContinuousShuffleBlockIdsIfNeeded(
+          blockInfos.map(info => FetchBlockInfo(info._1, info._2, info._3)),
+          doBatchFetch)
+        numBlocksToFetch += mergedBlockInfos.size
+        localBlocks ++= mergedBlockInfos.map(info => (info.blockId, info.mapIndex))
+        localBlockBytes += mergedBlockInfos.map(_.size).sum
+      } else if (
+        blockManager.hostLocalDirManager.isDefined &&
+        address.host == blockManager.blockManagerId.host
+      ) {
+        val mergedBlockInfos = mergeContinuousShuffleBlockIdsIfNeeded(
+          blockInfos.map(info => FetchBlockInfo(info._1, info._2, info._3)),
+          doBatchFetch)
+        numBlocksToFetch += mergedBlockInfos.size
+        val blocksForAddress =
+          mergedBlockInfos.map(info => (info.blockId, info.size, info.mapIndex))
+        hostLocalBlocksByExecutor += address -> blocksForAddress
+        numHostLocalBlocks += blocksForAddress.size
+        hostLocalBlockBytes += mergedBlockInfos.map(_.size).sum
+      } else {
+        val (_, timeCost) = Utils.timeTakenMs[Unit] {
+          collectFetchRequests(address, blockInfos, collectedRemoteRequests)
+        }
+        logDebug(s"Collected remote fetch requests for $address in $timeCost ms")
+      }
+    }
+    val (remoteBlockBytes, numRemoteBlocks) =
+      collectedRemoteRequests.foldLeft((0L, 0))((x, y) => (x._1 + y.size, x._2 + y.blocks.size))
+    val totalBytes = localBlockBytes + remoteBlockBytes + hostLocalBlockBytes +
+      pushMergedLocalBlockBytes
+    val blocksToFetchCurrentIteration = numBlocksToFetch - prevNumBlocksToFetch
+    // Every block counted above must have landed in exactly one of the four categories.
+    assert(
+      blocksToFetchCurrentIteration == localBlocks.size +
+        numHostLocalBlocks + numRemoteBlocks + pushMergedLocalBlocks.size,
+      s"The number of non-empty blocks $blocksToFetchCurrentIteration doesn't equal to the sum " +
+        s"of the number of local blocks ${localBlocks.size} + " +
+        s"the number of host-local blocks $numHostLocalBlocks + " +
+        s"the number of push-merged-local blocks ${pushMergedLocalBlocks.size} " +
+        s"+ the number of remote blocks $numRemoteBlocks "
+    )
+    logInfo(s"Getting $blocksToFetchCurrentIteration " +
+      s"(${Utils.bytesToString(totalBytes)}) non-empty blocks including " +
+      s"${localBlocks.size} (${Utils.bytesToString(localBlockBytes)}) local and " +
+      s"$numHostLocalBlocks (${Utils.bytesToString(hostLocalBlockBytes)}) " +
+      s"host-local and ${pushMergedLocalBlocks.size} " +
+      s"(${Utils.bytesToString(pushMergedLocalBlockBytes)}) " +
+      s"push-merged-local and $numRemoteBlocks (${Utils.bytesToString(remoteBlockBytes)}) " +
+      s"remote blocks")
+    this.hostLocalBlocks ++= hostLocalBlocksByExecutor.values
+      .flatMap(infos => infos.map(info => (info._1, info._3)))
+    collectedRemoteRequests
+  }
+
+  /**
+   * Wraps `blocks` destined for `address` into a single [[FetchRequest]], debug-logging the
+   * request's total size and block count.
+   */
+  private def createFetchRequest(
+      blocks: Seq[FetchBlockInfo],
+      address: BlockManagerId,
+      forMergedMetas: Boolean): FetchRequest = {
+    logDebug(
+      s"Creating fetch request of ${blocks.map(_.size).sum} at $address with ${blocks.size} blocks")
+    FetchRequest(address, blocks, forMergedMetas)
+  }
+
+  /**
+   * Turns `curBlocks` into fetch requests and appends them to `collectedRemoteRequests`. Blocks
+   * are first merged when `enableBatchFetch` is set, and each request holds at most
+   * maxBlocksInFlightPerAddress blocks. `numBlocksToFetch` grows by the number of merged blocks
+   * that end up in a request.
+   *
+   * @return the trailing, partially filled group when it was not turned into a request. That
+   *         only happens when `isLast` is false and there are more merged blocks than
+   *         maxBlocksInFlightPerAddress. Otherwise an empty buffer is returned.
+   */
+  private def createFetchRequests(
+      curBlocks: Seq[FetchBlockInfo],
+      address: BlockManagerId,
+      isLast: Boolean,
+      collectedRemoteRequests: ArrayBuffer[FetchRequest],
+      enableBatchFetch: Boolean,
+      forMergedMetas: Boolean = false): ArrayBuffer[FetchBlockInfo] = {
+    val merged = mergeContinuousShuffleBlockIdsIfNeeded(curBlocks, enableBatchFetch)
+    numBlocksToFetch += merged.size
+    val leftover = new ArrayBuffer[FetchBlockInfo]
+    if (merged.length > maxBlocksInFlightPerAddress) {
+      for (group <- merged.grouped(maxBlocksInFlightPerAddress)) {
+        val isFullGroup = group.length == maxBlocksInFlightPerAddress
+        if (isFullGroup || isLast) {
+          collectedRemoteRequests += createFetchRequest(group, address, forMergedMetas)
+        } else {
+          // Only the final group can be short. Hand it back to the caller as pending blocks
+          // and stop counting it for now.
+          leftover ++= group
+          numBlocksToFetch -= group.size
+        }
+      }
+    } else {
+      collectedRemoteRequests += createFetchRequest(merged, address, forMergedMetas)
+    }
+    leftover
+  }
+
+  /**
+   * Groups one address's blocks into remote [[FetchRequest]]s appended to
+   * `collectedRemoteRequests`.
+   *
+   * Blocks are accumulated in order and flushed through createFetchRequests, which also updates
+   * numBlocksToFetch. All blocks of one address are expected to share one type, and the flush
+   * condition depends on it:
+   *  - shuffle chunks: flushed when the pending size reaches targetRemoteRequestSize or the
+   *    pending count reaches maxBlocksInFlightPerAddress; never batch-fetched.
+   *  - push-merged blocks: flushed only when the pending count reaches
+   *    maxBlocksInFlightPerAddress; requests are created with forMergedMetas = true.
+   *  - original shuffle blocks: flushed when the pending size reaches targetRemoteRequestSize
+   *    or, without batch fetch, the pending count reaches maxBlocksInFlightPerAddress.
+   * Blocks handed back by a flush stay pending. Whatever remains at the end becomes the final
+   * request(s), with flags chosen from the type of the first pending block.
+   */
+  private def collectFetchRequests(
+      address: BlockManagerId,
+      blockInfos: Seq[(BlockId, Long, Int)],
+      collectedRemoteRequests: ArrayBuffer[FetchRequest]): Unit = {
+    val iterator = blockInfos.iterator
+    var curRequestSize = 0L
+    var curBlocks = new ArrayBuffer[FetchBlockInfo]()
+
+    while (iterator.hasNext) {
+      val (blockId, size, mapIndex) = iterator.next()
+      curBlocks += FetchBlockInfo(blockId, size, mapIndex)
+      curRequestSize += size
+      blockId match {
+        // Either all blocks are push-merged blocks, shuffle chunks, or original blocks.
+        // Based on these types, we decide to do batch fetch and create FetchRequests with
+        // forMergedMetas set.
+        case ShuffleBlockChunkId(_, _, _, _) =>
+          if (
+            curRequestSize >= targetRemoteRequestSize ||
+            curBlocks.size >= maxBlocksInFlightPerAddress
+          ) {
+            curBlocks = createFetchRequests(
+              curBlocks.toSeq,
+              address,
+              isLast = false,
+              collectedRemoteRequests,
+              enableBatchFetch = false)
+            // Restart the size from whatever blocks were handed back.
+            curRequestSize = curBlocks.map(_.size).sum
+          }
+        case ShuffleMergedBlockId(_, _, _) =>
+          // Only the block count matters here; curRequestSize is not consulted for this type,
+          // so it is not reset after a flush.
+          if (curBlocks.size >= maxBlocksInFlightPerAddress) {
+            curBlocks = createFetchRequests(
+              curBlocks.toSeq,
+              address,
+              isLast = false,
+              collectedRemoteRequests,
+              enableBatchFetch = false,
+              forMergedMetas = true)
+          }
+        case _ =>
+          // For batch fetch, the actual block in flight should count for merged block.
+          // With batch fetch, several blocks may merge into one, so the per-address block limit
+          // is enforced after merging, inside createFetchRequests.
+          val mayExceedsMaxBlocks = !doBatchFetch && curBlocks.size >= maxBlocksInFlightPerAddress
+          if (curRequestSize >= targetRemoteRequestSize || mayExceedsMaxBlocks) {
+            curBlocks = createFetchRequests(
+              curBlocks.toSeq,
+              address,
+              isLast = false,
+              collectedRemoteRequests,
+              doBatchFetch)
+            curRequestSize = curBlocks.map(_.size).sum
+          }
+      }
+    }
+    // Add in the final request
+    if (curBlocks.nonEmpty) {
+      // (enableBatchFetch, forMergedMetas) for the remaining blocks, picked from their type.
+      val (enableBatchFetch, forMergedMetas) = {
+        curBlocks.head.blockId match {
+          case ShuffleBlockChunkId(_, _, _, _) => (false, false)
+          case ShuffleMergedBlockId(_, _, _) => (false, true)
+          case _ => (doBatchFetch, false)
+        }
+      }
+      createFetchRequests(
+        curBlocks.toSeq,
+        address,
+        isLast = true,
+        collectedRemoteRequests,
+        enableBatchFetch = enableBatchFetch,
+        forMergedMetas = forMergedMetas)
+    }
+  }
+
+  private def assertPositiveBlockSize(blockId: BlockId, blockSize: Long): Unit 
= {
+    if (blockSize < 0) {
+      throw BlockException(blockId, "Negative block size " + size)

Review Comment:
   Copied from spark.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to