Copilot commented on code in PR #12370:
URL: https://github.com/apache/gluten/pull/12370#discussion_r3505163705
##########
cpp/velox/utils/CachedBatchQueue.h:
##########
@@ -0,0 +1,95 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include <glog/logging.h>
+#include "velox/common/base/Exceptions.h"
+
+#include <condition_variable>
+#include <memory>
+#include <mutex>
+#include <queue>
+
+namespace gluten {
+
+template <typename T>
+class CachedBatchQueue {
+ public:
+  explicit CachedBatchQueue(const int64_t capacity) : capacity_(capacity) {}
+
+  void put(std::shared_ptr<T> batch) {
+    std::unique_lock<std::mutex> lock(mtx_);
+    VELOX_CHECK(!noMoreBatches_, "Cannot put batch after noMoreBatches() is called");
+
+    const auto batchSize = batch->numBytes();
+    VELOX_CHECK_LE(batchSize, capacity_, "Batch size exceeds queue capacity");
+
+    notFull_.wait(lock, [&]() { return totalSize_ + batchSize <= capacity_; });
+
+    queue_.push(std::move(batch));
+    totalSize_ += batchSize;
+
+    notEmpty_.notify_one();
+  }
+
+  std::shared_ptr<T> get() {
+    std::unique_lock<std::mutex> lock(mtx_);
+    notEmpty_.wait(lock, [&]() { return noMoreBatches_ || !queue_.empty(); });
+
+    if (queue_.empty()) {
+      return nullptr;
+    }
+    auto batch = std::move(queue_.front());
+    LOG(INFO) << "Trying to get from cached buffer queue. Queue length: " << queue_.size()
+              << ", total size in queue: " << totalSize_ << ", current batch size: " << batch->numBytes() << std::endl;
+

Review Comment:
`CachedBatchQueue::get()` logs at INFO for every dequeued batch (and also flushes via `std::endl`). This is likely to be extremely noisy and expensive in the shuffle fast path. Consider removing this log or downgrading it to `VLOG` without `std::endl`.

##########
cpp/velox/utils/CachedBatchQueue.h:
##########
@@ -0,0 +1,95 @@
+#pragma once
+
+#include <glog/logging.h>
+#include "velox/common/base/Exceptions.h"
+
+#include <condition_variable>
+#include <memory>
+#include <mutex>
+#include <queue>
+
+namespace gluten {
+
+template <typename T>
+class CachedBatchQueue {
+ public:
+  explicit CachedBatchQueue(const int64_t capacity) : capacity_(capacity) {}
+
+  void put(std::shared_ptr<T> batch) {
+    std::unique_lock<std::mutex> lock(mtx_);
+    VELOX_CHECK(!noMoreBatches_, "Cannot put batch after noMoreBatches() is called");
+
+    const auto batchSize = batch->numBytes();
+    VELOX_CHECK_LE(batchSize, capacity_, "Batch size exceeds queue capacity");
+
+    notFull_.wait(lock, [&]() { return totalSize_ + batchSize <= capacity_; });
+

Review Comment:
`CachedBatchQueue::put()` can deadlock if `noMoreBatches()` is called while a producer is blocked in `notFull_.wait(...)`: `noMoreBatches()` notifies `notFull_`, but the wait predicate ignores `noMoreBatches_` and may never become true, so the producer can wait forever. Include `noMoreBatches_` in the predicate and re-check it after waking so the call fails fast.

##########
gluten-arrow/src/main/scala/org/apache/spark/storage/SparkInputStreamUtil.scala:
##########
@@ -19,7 +19,7 @@ package org.apache.spark.storage
 import java.io.InputStream
 object SparkInputStreamUtil {
-  def unwrapBufferReleasingInputStream(in: BufferReleasingInputStream): InputStream = {
+  def unwrapBufferReleasingInputStream(in: GlutenBufferReleasingInputStream): InputStream = {
     in.delegate
   }

Review Comment:
`unwrapBufferReleasingInputStream` takes `GlutenBufferReleasingInputStream`, which forces other modules (like Java `JniByteInputStreams`) to reference that shim-internal type directly. Since `GlutenBufferReleasingInputStream` is currently declared `private` in the shims, this makes cross-module compilation fragile. Consider changing this utility to accept a plain `InputStream` and do the unwrapping internally (pattern match) to avoid leaking the shim wrapper type.
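The two `CachedBatchQueue` comments above (per-batch INFO logging in `get()` and the `put()` deadlock) can be illustrated with a minimal sketch. It is a hypothetical, standard-library-only stand-in: `CachedBatchQueueSketch`, `FakeBatch`, and `producerReleasedOnClose` are illustrative names, not Gluten code; `VELOX_CHECK` is replaced by standard exceptions; and the parts of `get()` and `noMoreBatches()` that the quoted hunk does not show (popping, size accounting, notifications) are assumptions. The sketch adds `noMoreBatches_` to the `put()` wait predicate, re-checks it after waking, and keeps `get()` free of per-batch logging.

```cpp
#include <cassert>
#include <condition_variable>
#include <cstdint>
#include <memory>
#include <mutex>
#include <queue>
#include <stdexcept>
#include <thread>

// Hypothetical stand-in for gluten::CachedBatchQueue (assumption: T provides int64_t numBytes() const).
template <typename T>
class CachedBatchQueueSketch {
 public:
  explicit CachedBatchQueueSketch(int64_t capacity) : capacity_(capacity) {}

  void put(std::shared_ptr<T> batch) {
    std::unique_lock<std::mutex> lock(mtx_);
    if (noMoreBatches_) {
      throw std::logic_error("Cannot put batch after noMoreBatches() is called");
    }
    const int64_t batchSize = batch->numBytes();
    if (batchSize > capacity_) {
      throw std::invalid_argument("Batch size exceeds queue capacity");
    }
    // Also wake up when the queue is closed, so a blocked producer cannot wait forever.
    notFull_.wait(lock, [&] { return noMoreBatches_ || totalSize_ + batchSize <= capacity_; });
    if (noMoreBatches_) {
      // Fail fast instead of enqueuing into a closed queue.
      throw std::logic_error("noMoreBatches() was called while put() was waiting");
    }
    queue_.push(std::move(batch));
    totalSize_ += batchSize;
    notEmpty_.notify_one();
  }

  // Returns nullptr once the queue is closed and drained; no per-batch logging on this hot path.
  std::shared_ptr<T> get() {
    std::unique_lock<std::mutex> lock(mtx_);
    notEmpty_.wait(lock, [&] { return noMoreBatches_ || !queue_.empty(); });
    if (queue_.empty()) {
      return nullptr;
    }
    auto batch = std::move(queue_.front());
    queue_.pop();
    totalSize_ -= batch->numBytes();
    assert(totalSize_ >= 0);
    // Waiting producers may need different amounts of space, so wake all of them to re-check.
    notFull_.notify_all();
    return batch;
  }

  void noMoreBatches() {
    {
      std::lock_guard<std::mutex> lock(mtx_);
      noMoreBatches_ = true;
    }
    notEmpty_.notify_all();
    notFull_.notify_all();
  }

 private:
  const int64_t capacity_;
  int64_t totalSize_ = 0;
  bool noMoreBatches_ = false;
  std::queue<std::shared_ptr<T>> queue_;
  std::mutex mtx_;
  std::condition_variable notEmpty_;
  std::condition_variable notFull_;
};

// Hypothetical batch type, used only for the demo below.
struct FakeBatch {
  int64_t bytes;
  int64_t numBytes() const { return bytes; }
};

// Demo: a producer blocked on a full queue is released by noMoreBatches() instead of hanging.
inline bool producerReleasedOnClose() {
  CachedBatchQueueSketch<FakeBatch> q(10);
  q.put(std::make_shared<FakeBatch>(FakeBatch{8}));
  bool threw = false;
  std::thread producer([&] {
    try {
      q.put(std::make_shared<FakeBatch>(FakeBatch{5}));  // 8 + 5 > 10, so this blocks
    } catch (const std::logic_error&) {
      threw = true;
    }
  });
  q.noMoreBatches();  // with the original predicate, join() below could hang forever
  producer.join();
  return threw;
}
```

Throwing after a closed wait follows the reviewer's "fail fast" suggestion. If callers should treat a closed queue as a normal shutdown signal rather than an error, having `put()` return a `bool` would be an alternative design.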
##########
cpp/velox/compute/VeloxBackend.cc:
##########
@@ -294,12 +294,19 @@ void VeloxBackend::init(
   registerShuffleDictionaryWriterFactory([](MemoryManager* memoryManager, arrow::util::Codec* codec) {
     return std::make_unique<ArrowShuffleDictionaryWriter>(memoryManager, codec);
   });
+
+  readerThreadPool_ = std::make_unique<ReaderThreadPool>(
+      backendConf_->get<int32_t>(kShuffleReaderThreads, std::thread::hardware_concurrency()));

Review Comment:
`std::thread::hardware_concurrency()` is permitted to return 0, and the config could also be set to 0. Creating `ReaderThreadPool(0)` would result in zero reader tasks, and the GPU shuffle reader would block forever on `CachedBatchQueue::get()` (since `noMoreBatches()` is only called by reader threads). Clamp the thread count to at least 1 (and ideally validate the conf value).

##########
shims/spark33/src/main/scala/org/apache/spark/storage/GlutenShuffleBlockFetcherIterator.scala:
##########
@@ -0,0 +1,1506 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ +package org.apache.spark.storage + +import org.apache.spark.{MapOutputTracker, TaskContext} +import org.apache.spark.MapOutputTracker.SHUFFLE_PUSH_MAP_ID +import org.apache.spark.errors.SparkCoreErrors +import org.apache.spark.internal.Logging +import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} +import org.apache.spark.network.shuffle._ +import org.apache.spark.network.shuffle.checksum.{Cause, ShuffleChecksumHelper} +import org.apache.spark.network.util.TransportConf +import org.apache.spark.shuffle.ShuffleReadMetricsReporter +import org.apache.spark.util.{TaskCompletionListener, Utils} + +import io.netty.util.internal.OutOfDirectMemoryError +import org.apache.commons.io.IOUtils + +import javax.annotation.concurrent.GuardedBy + +import java.io.{InputStream, IOException} +import java.nio.channels.ClosedByInterruptException +import java.util.concurrent.{ConcurrentHashMap, LinkedBlockingQueue, TimeUnit} +import java.util.zip.CheckedInputStream + +import scala.collection.mutable +import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Queue} +import scala.util.{Failure, Success} + +/** + * An iterator that fetches multiple blocks. For local blocks, it fetches from the local block + * manager. For remote blocks, it fetches them using the provided BlockTransferService. + * + * This creates an iterator of (BlockID, InputStream) tuples so the caller can handle blocks in a + * pipelined fashion as they are received. + * + * The implementation throttles the remote fetches so they don't exceed maxBytesInFlight to avoid + * using too much memory. + * + * @param context + * [[TaskContext]], used for metrics update + * @param shuffleClient + * [[BlockStoreClient]] for fetching remote blocks + * @param blockManager + * [[BlockManager]] for reading local blocks + * @param blocksByAddress + * list of blocks to fetch grouped by the [[BlockManagerId]]. For each block we also require two + * info: 1. 
the size (in bytes as a long field) in order to throttle the memory usage; 2. the + * mapIndex for this block, which indicate the index in the map stage. Note that zero-sized blocks + * are already excluded, which happened in + * [[org.apache.spark.MapOutputTracker.convertMapStatuses]]. + * @param mapOutputTracker + * [[MapOutputTracker]] for falling back to fetching the original blocks if we fail to fetch + * shuffle chunks when push based shuffle is enabled. + * @param streamWrapper + * A function to wrap the returned input stream. + * @param maxBytesInFlight + * max size (in bytes) of remote blocks to fetch at any given point. + * @param maxReqsInFlight + * max number of remote requests to fetch blocks at any given point. + * @param maxBlocksInFlightPerAddress + * max number of shuffle blocks being fetched at any given point for a given remote host:port. + * @param maxReqSizeShuffleToMem + * max size (in bytes) of a request that can be shuffled to memory. + * @param maxAttemptsOnNettyOOM + * The max number of a block could retry due to Netty OOM before throwing the fetch failure. + * @param detectCorrupt + * whether to detect any corruption in fetched blocks. + * @param checksumEnabled + * whether the shuffle checksum is enabled. When enabled, Spark will try to diagnose the cause of + * the block corruption. + * @param checksumAlgorithm + * the checksum algorithm that is used when calculating the checksum value for the block data. + * @param shuffleMetrics + * used to report shuffle metrics. + * @param doBatchFetch + * fetch continuous shuffle blocks from same executor in batch if the server side supports. 
+ */ +final class GlutenShuffleBlockFetcherIterator( + context: TaskContext, + shuffleClient: BlockStoreClient, + blockManager: BlockManager, + mapOutputTracker: MapOutputTracker, + blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])], + streamWrapper: (BlockId, InputStream) => InputStream, + maxBytesInFlight: Long, + maxReqsInFlight: Int, + maxBlocksInFlightPerAddress: Int, + val maxReqSizeShuffleToMem: Long, + maxAttemptsOnNettyOOM: Int, + detectCorrupt: Boolean, + detectCorruptUseExtraMemory: Boolean, + checksumEnabled: Boolean, + checksumAlgorithm: String, + shuffleMetrics: ShuffleReadMetricsReporter, + doBatchFetch: Boolean) + extends GlutenShuffleBlockFetcherIteratorBase + with DownloadFileManager + with Logging { + + import ShuffleBlockFetcherIterator._ + + // Make remote requests at most maxBytesInFlight / 5 in length; the reason to keep them + // smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5 + // nodes, rather than blocking on reading output from one node. + private val targetRemoteRequestSize = math.max(maxBytesInFlight / 5, 1L) + + /** + * Total number of blocks to fetch. + */ + private[this] var numBlocksToFetch = 0 + + /** + * The number of blocks processed by the caller. The iterator is exhausted when + * [[numBlocksProcessed]] == [[numBlocksToFetch]]. + */ + private[this] var numBlocksProcessed = 0 + + private[this] val startTimeNs = System.nanoTime() + + /** Host local blocks to fetch, excluding zero-sized blocks. */ + private[this] val hostLocalBlocks = scala.collection.mutable.LinkedHashSet[(BlockId, Int)]() + + /** + * A queue to hold our results. This turns the asynchronous model provided by + * [[org.apache.spark.network.BlockTransferService]] into a synchronous model (iterator). + */ + private[this] val results = new LinkedBlockingQueue[FetchResult] + + /** + * Current [[FetchResult]] being processed per thread. 
We track this so we can release the current + * buffer in case of a runtime exception when processing the current buffer. Using + * ConcurrentHashMap to support concurrent access from multiple threads while allowing cleanup + * from any thread. + */ + private[this] val currentResults: ConcurrentHashMap[Long, SuccessFetchResult] = + new ConcurrentHashMap[Long, SuccessFetchResult]() + + /** + * Queue of fetch requests to issue; we'll pull requests off this gradually to make sure that the + * number of bytes in flight is limited to maxBytesInFlight. + */ + private[this] val fetchRequests = new Queue[FetchRequest] + + /** + * Queue of fetch requests which could not be issued the first time they were dequeued. These + * requests are tried again when the fetch constraints are satisfied. + */ + private[this] val deferredFetchRequests = new HashMap[BlockManagerId, Queue[FetchRequest]]() + + /** Current bytes in flight from our requests */ + private[this] var bytesInFlight = 0L + + /** Current number of requests in flight */ + private[this] var reqsInFlight = 0 + + /** Current number of blocks in flight per host:port */ + private[this] val numBlocksInFlightPerAddress = new HashMap[BlockManagerId, Int]() + + /** + * Count the retry times for the blocks due to Netty OOM. The block will stop retry if retry times + * has exceeded the [[maxAttemptsOnNettyOOM]]. + */ + private[this] val blockOOMRetryCounts = new HashMap[String, Int] + + /** + * The blocks that can't be decompressed successfully, it is used to guarantee that we retry at + * most once for those corrupted blocks. + */ + private[this] val corruptedBlocks = mutable.HashSet[BlockId]() + + /** + * Whether the iterator is still active. If isZombie is true, the callback interface will no + * longer place fetched blocks into [[results]]. + */ + @GuardedBy("this") + private[this] var isZombie = false + + /** + * A set to store the files used for shuffling remote huge blocks. 
Files in this set will be + * deleted when cleanup. This is a layer of defensiveness against disk file leaks. + */ + @GuardedBy("this") + private[this] val shuffleFilesSet = mutable.HashSet[DownloadFile]() + + private[this] val onCompleteCallback = new GlutenShuffleFetchCompletionListener(this) + + private[this] val pushBasedFetchHelper = new GlutenPushBasedFetchHelper( + this, + shuffleClient, + blockManager, + mapOutputTracker) + + initialize() + + // Decrements the buffer reference count. + // The currentResult is removed from the map to prevent releasing the buffer again on cleanup() + private[storage] def releaseCurrentResultBuffer(): Unit = { + val threadId = Thread.currentThread().getId + // Release the current buffer if necessary + val result = currentResults.remove(threadId) + if (result != null) { + result.buf.release() + } + } + + override def createTempFile(transportConf: TransportConf): DownloadFile = { + // we never need to do any encryption or decryption here, regardless of configs, because that + // is handled at another layer in the code. When encryption is enabled, shuffle data is written + // to disk encrypted in the first place, and sent over the network still encrypted. + new SimpleDownloadFile( + blockManager.diskBlockManager.createTempLocalBlock()._2, + transportConf) + } + + override def registerTempFileToClean(file: DownloadFile): Boolean = synchronized { + if (isZombie) { + false + } else { + shuffleFilesSet += file + true + } + } + + /** + * Mark the iterator as zombie, and release all buffers that haven't been deserialized yet. 
+ */ + private[storage] def cleanup(): Unit = { + synchronized { + isZombie = true + } + releaseCurrentResultBuffer() + // Release buffers in the results queue + val iter = results.iterator() + while (iter.hasNext) { + val result = iter.next() + result match { + case SuccessFetchResult(blockId, mapIndex, address, _, buf, _) => + if (address != blockManager.blockManagerId) { + if (hostLocalBlocks.contains(blockId -> mapIndex)) { + shuffleMetrics.incLocalBlocksFetched(1) + shuffleMetrics.incLocalBytesRead(buf.size) + } else { + shuffleMetrics.incRemoteBytesRead(buf.size) + if (buf.isInstanceOf[FileSegmentManagedBuffer]) { + shuffleMetrics.incRemoteBytesReadToDisk(buf.size) + } + shuffleMetrics.incRemoteBlocksFetched(1) + } + } + buf.release() + case _ => + } + } + shuffleFilesSet.foreach { + file => + if (!file.delete()) { + logWarning("Failed to cleanup shuffle fetch temp file " + file.path()) + } + } + } + + private[this] def sendRequest(req: FetchRequest): Unit = { + logDebug("Sending request for %d blocks (%s) from %s".format( + req.blocks.size, + Utils.bytesToString(req.size), + req.address.hostPort)) + bytesInFlight += req.size + reqsInFlight += 1 + + // so we can look up the block info of each blockID + val infoMap = req.blocks.map { + case FetchBlockInfo(blockId, size, mapIndex) => (blockId.toString, (size, mapIndex)) + }.toMap + val remainingBlocks = new HashSet[String]() ++= infoMap.keys + val deferredBlocks = new ArrayBuffer[String]() + val blockIds = req.blocks.map(_.blockId.toString) + val address = req.address + + @inline def enqueueDeferredFetchRequestIfNecessary(): Unit = { + if (remainingBlocks.isEmpty && deferredBlocks.nonEmpty) { + val blocks = deferredBlocks.map { + blockId => + val (size, mapIndex) = infoMap(blockId) + FetchBlockInfo(BlockId(blockId), size, mapIndex) + } + results.put(DeferFetchRequestResult(FetchRequest(address, blocks.toSeq))) + deferredBlocks.clear() + } + } + + val blockFetchingListener = new BlockFetchingListener { + 
override def onBlockFetchSuccess(blockId: String, buf: ManagedBuffer): Unit = { + // Only add the buffer to results queue if the iterator is not zombie, + // i.e. cleanup() has not been called yet. + GlutenShuffleBlockFetcherIterator.this.synchronized { + if (!isZombie) { + // Increment the ref count because we need to pass this to a different thread. + // This needs to be released after use. + buf.retain() + remainingBlocks -= blockId + blockOOMRetryCounts.remove(blockId) + results.put(new SuccessFetchResult( + BlockId(blockId), + infoMap(blockId)._2, + address, + infoMap(blockId)._1, + buf, + remainingBlocks.isEmpty)) + logDebug("remainingBlocks: " + remainingBlocks) + enqueueDeferredFetchRequestIfNecessary() + } + } + logTrace(s"Got remote block $blockId after ${Utils.getUsedTimeNs(startTimeNs)}") + } + + override def onBlockFetchFailure(blockId: String, e: Throwable): Unit = { + GlutenShuffleBlockFetcherIterator.this.synchronized { + logError(s"Failed to get block(s) from ${req.address.host}:${req.address.port}", e) + e match { + // SPARK-27991: Catch the Netty OOM and set the flag `isNettyOOMOnShuffle` (shared among + // tasks) to true as early as possible. The pending fetch requests won't be sent + // afterwards until the flag is set to false on: + // 1) the Netty free memory >= maxReqSizeShuffleToMem + // - we'll check this whenever there's a fetch request succeeds. + // 2) the number of in-flight requests becomes 0 + // - we'll check this in `fetchUpToMaxBytes` whenever it's invoked. + // Although Netty memory is shared across multiple modules, e.g., shuffle, rpc, the flag + // only takes effect for the shuffle due to the implementation simplicity concern. + // And we'll buffer the consecutive block failures caused by the OOM error until there's + // no remaining blocks in the current request. Then, we'll package these blocks into + // a same fetch request for the retry later. 
In this way, instead of creating the fetch + // request per block, it would help reduce the concurrent connections and data loads + // pressure at remote server. + // Note that catching OOM and do something based on it is only a workaround for + // handling the Netty OOM issue, which is not the best way towards memory management. + // We can get rid of it when we find a way to manage Netty's memory precisely. + case _: OutOfDirectMemoryError + if blockOOMRetryCounts.getOrElseUpdate(blockId, 0) < maxAttemptsOnNettyOOM => + if (!isZombie) { + val failureTimes = blockOOMRetryCounts(blockId) + blockOOMRetryCounts(blockId) += 1 + if (isNettyOOMOnShuffle.compareAndSet(false, true)) { + // The fetcher can fail remaining blocks in batch for the same error. So we only + // log the warning once to avoid flooding the logs. + logInfo(s"Block $blockId has failed $failureTimes times " + + s"due to Netty OOM, will retry") + } + remainingBlocks -= blockId + deferredBlocks += blockId + enqueueDeferredFetchRequestIfNecessary() + } + + case _ => + val block = BlockId(blockId) + if (block.isShuffleChunk) { + remainingBlocks -= blockId + results.put(FallbackOnPushMergedFailureResult( + block, + address, + infoMap(blockId)._1, + remainingBlocks.isEmpty)) + } else { + results.put(FailureFetchResult(block, infoMap(blockId)._2, address, e)) + } + } + } + } + } + + // Fetch remote shuffle blocks to disk when the request is too large. Since the shuffle data is + // already encrypted and compressed over the wire(w.r.t. the related configs), we can just fetch + // the data and write it to file directly. 
+ if (req.size > maxReqSizeShuffleToMem) { + shuffleClient.fetchBlocks( + address.host, + address.port, + address.executorId, + blockIds.toArray, + blockFetchingListener, + this) + } else { + shuffleClient.fetchBlocks( + address.host, + address.port, + address.executorId, + blockIds.toArray, + blockFetchingListener, + null) + } + } + + /** + * This is called from initialize and also from the fallback which is triggered from + * [[PushBasedFetchHelper]]. + */ + private[this] def partitionBlocksByFetchMode( + blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])], + localBlocks: mutable.LinkedHashSet[(BlockId, Int)], + hostLocalBlocksByExecutor: mutable.LinkedHashMap[BlockManagerId, Seq[(BlockId, Long, Int)]], + pushMergedLocalBlocks: mutable.LinkedHashSet[BlockId]): ArrayBuffer[FetchRequest] = { + logDebug(s"maxBytesInFlight: $maxBytesInFlight, targetRemoteRequestSize: " + + s"$targetRemoteRequestSize, maxBlocksInFlightPerAddress: $maxBlocksInFlightPerAddress") + + // Partition to local, host-local, push-merged-local, remote (includes push-merged-remote) + // blocks.Remote blocks are further split into FetchRequests of size at most maxBytesInFlight + // in order to limit the amount of data in flight + val collectedRemoteRequests = new ArrayBuffer[FetchRequest] + var localBlockBytes = 0L + var hostLocalBlockBytes = 0L + var numHostLocalBlocks = 0 + var pushMergedLocalBlockBytes = 0L + val prevNumBlocksToFetch = numBlocksToFetch + + val fallback = FallbackStorage.FALLBACK_BLOCK_MANAGER_ID.executorId + val localExecIds = Set(blockManager.blockManagerId.executorId, fallback) + for ((address, blockInfos) <- blocksByAddress) { + checkBlockSizes(blockInfos) + if (pushBasedFetchHelper.isPushMergedShuffleBlockAddress(address)) { + // These are push-merged blocks or shuffle chunks of these blocks. 
+ if (address.host == blockManager.blockManagerId.host) { + numBlocksToFetch += blockInfos.size + pushMergedLocalBlocks ++= blockInfos.map(_._1) + pushMergedLocalBlockBytes += blockInfos.map(_._2).sum + } else { + collectFetchRequests(address, blockInfos, collectedRemoteRequests) + } + } else if (localExecIds.contains(address.executorId)) { + val mergedBlockInfos = mergeContinuousShuffleBlockIdsIfNeeded( + blockInfos.map(info => FetchBlockInfo(info._1, info._2, info._3)), + doBatchFetch) + numBlocksToFetch += mergedBlockInfos.size + localBlocks ++= mergedBlockInfos.map(info => (info.blockId, info.mapIndex)) + localBlockBytes += mergedBlockInfos.map(_.size).sum + } else if ( + blockManager.hostLocalDirManager.isDefined && + address.host == blockManager.blockManagerId.host + ) { + val mergedBlockInfos = mergeContinuousShuffleBlockIdsIfNeeded( + blockInfos.map(info => FetchBlockInfo(info._1, info._2, info._3)), + doBatchFetch) + numBlocksToFetch += mergedBlockInfos.size + val blocksForAddress = + mergedBlockInfos.map(info => (info.blockId, info.size, info.mapIndex)) + hostLocalBlocksByExecutor += address -> blocksForAddress + numHostLocalBlocks += blocksForAddress.size + hostLocalBlockBytes += mergedBlockInfos.map(_.size).sum + } else { + val (_, timeCost) = Utils.timeTakenMs[Unit] { + collectFetchRequests(address, blockInfos, collectedRemoteRequests) + } + logDebug(s"Collected remote fetch requests for $address in $timeCost ms") + } + } + val (remoteBlockBytes, numRemoteBlocks) = + collectedRemoteRequests.foldLeft((0L, 0))((x, y) => (x._1 + y.size, x._2 + y.blocks.size)) + val totalBytes = localBlockBytes + remoteBlockBytes + hostLocalBlockBytes + + pushMergedLocalBlockBytes + val blocksToFetchCurrentIteration = numBlocksToFetch - prevNumBlocksToFetch + assert( + blocksToFetchCurrentIteration == localBlocks.size + + numHostLocalBlocks + numRemoteBlocks + pushMergedLocalBlocks.size, + s"The number of non-empty blocks $blocksToFetchCurrentIteration doesn't equal to the 
sum " + + s"of the number of local blocks ${localBlocks.size} + " + + s"the number of host-local blocks $numHostLocalBlocks " + + s"the number of push-merged-local blocks ${pushMergedLocalBlocks.size} " + + s"+ the number of remote blocks $numRemoteBlocks " + ) + logInfo(s"Getting $blocksToFetchCurrentIteration " + + s"(${Utils.bytesToString(totalBytes)}) non-empty blocks including " + + s"${localBlocks.size} (${Utils.bytesToString(localBlockBytes)}) local and " + + s"$numHostLocalBlocks (${Utils.bytesToString(hostLocalBlockBytes)}) " + + s"host-local and ${pushMergedLocalBlocks.size} " + + s"(${Utils.bytesToString(pushMergedLocalBlockBytes)}) " + + s"push-merged-local and $numRemoteBlocks (${Utils.bytesToString(remoteBlockBytes)}) " + + s"remote blocks") + this.hostLocalBlocks ++= hostLocalBlocksByExecutor.values + .flatMap(infos => infos.map(info => (info._1, info._3))) + collectedRemoteRequests + } + + private def createFetchRequest( + blocks: Seq[FetchBlockInfo], + address: BlockManagerId, + forMergedMetas: Boolean): FetchRequest = { + logDebug(s"Creating fetch request of ${blocks.map(_.size).sum} at $address " + + s"with ${blocks.size} blocks") + FetchRequest(address, blocks, forMergedMetas) + } + + private def createFetchRequests( + curBlocks: Seq[FetchBlockInfo], + address: BlockManagerId, + isLast: Boolean, + collectedRemoteRequests: ArrayBuffer[FetchRequest], + enableBatchFetch: Boolean, + forMergedMetas: Boolean = false): ArrayBuffer[FetchBlockInfo] = { + val mergedBlocks = mergeContinuousShuffleBlockIdsIfNeeded(curBlocks, enableBatchFetch) + numBlocksToFetch += mergedBlocks.size + val retBlocks = new ArrayBuffer[FetchBlockInfo] + if (mergedBlocks.length <= maxBlocksInFlightPerAddress) { + collectedRemoteRequests += createFetchRequest(mergedBlocks, address, forMergedMetas) + } else { + mergedBlocks.grouped(maxBlocksInFlightPerAddress).foreach { + blocks => + if (blocks.length == maxBlocksInFlightPerAddress || isLast) { + collectedRemoteRequests += 
createFetchRequest(blocks, address, forMergedMetas) + } else { + // The last group does not exceed `maxBlocksInFlightPerAddress`. Put it back + // to `curBlocks`. + retBlocks ++= blocks + numBlocksToFetch -= blocks.size + } + } + } + retBlocks + } + + private def collectFetchRequests( + address: BlockManagerId, + blockInfos: Seq[(BlockId, Long, Int)], + collectedRemoteRequests: ArrayBuffer[FetchRequest]): Unit = { + val iterator = blockInfos.iterator + var curRequestSize = 0L + var curBlocks = new ArrayBuffer[FetchBlockInfo]() + + while (iterator.hasNext) { + val (blockId, size, mapIndex) = iterator.next() + curBlocks += FetchBlockInfo(blockId, size, mapIndex) + curRequestSize += size + blockId match { + // Either all blocks are push-merged blocks, shuffle chunks, or original blocks. + // Based on these types, we decide to do batch fetch and create FetchRequests with + // forMergedMetas set. + case ShuffleBlockChunkId(_, _, _, _) => + if ( + curRequestSize >= targetRemoteRequestSize || + curBlocks.size >= maxBlocksInFlightPerAddress + ) { + curBlocks = createFetchRequests( + curBlocks.toSeq, + address, + isLast = false, + collectedRemoteRequests, + enableBatchFetch = false) + curRequestSize = curBlocks.map(_.size).sum + } + case ShuffleMergedBlockId(_, _, _) => + if (curBlocks.size >= maxBlocksInFlightPerAddress) { + curBlocks = createFetchRequests( + curBlocks.toSeq, + address, + isLast = false, + collectedRemoteRequests, + enableBatchFetch = false, + forMergedMetas = true) + } + case _ => + // For batch fetch, the actual block in flight should count for merged block. 
+ val mayExceedsMaxBlocks = !doBatchFetch && curBlocks.size >= maxBlocksInFlightPerAddress + if (curRequestSize >= targetRemoteRequestSize || mayExceedsMaxBlocks) { + curBlocks = createFetchRequests( + curBlocks.toSeq, + address, + isLast = false, + collectedRemoteRequests, + doBatchFetch) + curRequestSize = curBlocks.map(_.size).sum + } + } + } + // Add in the final request + if (curBlocks.nonEmpty) { + val (enableBatchFetch, forMergedMetas) = { + curBlocks.head.blockId match { + case ShuffleBlockChunkId(_, _, _, _) => (false, false) + case ShuffleMergedBlockId(_, _, _) => (false, true) + case _ => (doBatchFetch, false) + } + } + createFetchRequests( + curBlocks.toSeq, + address, + isLast = true, + collectedRemoteRequests, + enableBatchFetch = enableBatchFetch, + forMergedMetas = forMergedMetas) + } + } + + private def assertPositiveBlockSize(blockId: BlockId, blockSize: Long): Unit = { + if (blockSize < 0) { + throw BlockException(blockId, "Negative block size " + size)

Review Comment:
`assertPositiveBlockSize` throws using an undefined identifier `size`, which will not compile. It should reference the `blockSize` parameter when building the error message.
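Returning to the `VeloxBackend.cc` comment on `hardware_concurrency()`, the suggested guard might look like the minimal sketch below. `resolveReaderThreads` is a hypothetical helper, not Gluten code, and passing an explicitly configured `kShuffleReaderThreads` value as a `std::optional` is an assumption; the reviewed code instead calls `backendConf_->get<int32_t>(...)` with a default.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <stdexcept>
#include <thread>

// Hypothetical helper: picks the thread count for ReaderThreadPool.
// `configured` stands in for an explicitly set kShuffleReaderThreads value, if any.
inline int32_t resolveReaderThreads(std::optional<int32_t> configured) {
  if (configured.has_value()) {
    // Validate the conf value: zero or negative would leave the pool with no reader tasks.
    if (*configured < 1) {
      throw std::invalid_argument("kShuffleReaderThreads must be at least 1");
    }
    return *configured;
  }
  // hardware_concurrency() may return 0 when the value is not computable; clamp to 1.
  const unsigned hw = std::thread::hardware_concurrency();
  return hw == 0 ? 1 : static_cast<int32_t>(hw);
}
```

Whether to reject or silently clamp a non-positive conf value is a design choice. This sketch rejects it, so a misconfiguration surfaces when the backend is initialized rather than as a hang in the GPU shuffle reader.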
##########
gluten-arrow/src/main/java/org/apache/gluten/vectorized/JniByteInputStreams.java:
##########
@@ -58,8 +58,8 @@ public static JniByteInputStream create(InputStream in) {
   static InputStream unwrapSparkInputStream(InputStream in) {
     InputStream unwrapped = in;
-    if (unwrapped instanceof BufferReleasingInputStream) {
-      final BufferReleasingInputStream brin = (BufferReleasingInputStream) unwrapped;
+    if (unwrapped instanceof GlutenBufferReleasingInputStream) {
+      final GlutenBufferReleasingInputStream brin = (GlutenBufferReleasingInputStream) unwrapped;
       unwrapped = org.apache.spark.storage.SparkInputStreamUtil.unwrapBufferReleasingInputStream(brin);
     }

Review Comment:
This `instanceof GlutenBufferReleasingInputStream` cast introduces a compile-time dependency on a shim-internal type that is currently declared as a `private class` in the Scala shims. That makes compilation fragile / likely to fail. Prefer moving the unwrapping logic behind a public helper that accepts `InputStream` (so this code doesn’t need to reference the shim class), or make the wrapper type public across shims.

##########
shims/spark33/src/main/scala/org/apache/spark/storage/GlutenShuffleBlockFetcherIterator.scala:
##########
@@ -0,0 +1,1506 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.storage + +import org.apache.spark.{MapOutputTracker, TaskContext} +import org.apache.spark.MapOutputTracker.SHUFFLE_PUSH_MAP_ID +import org.apache.spark.errors.SparkCoreErrors +import org.apache.spark.internal.Logging +import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} +import org.apache.spark.network.shuffle._ +import org.apache.spark.network.shuffle.checksum.{Cause, ShuffleChecksumHelper} +import org.apache.spark.network.util.TransportConf +import org.apache.spark.shuffle.ShuffleReadMetricsReporter +import org.apache.spark.util.{TaskCompletionListener, Utils} + +import io.netty.util.internal.OutOfDirectMemoryError +import org.apache.commons.io.IOUtils + +import javax.annotation.concurrent.GuardedBy + +import java.io.{InputStream, IOException} +import java.nio.channels.ClosedByInterruptException +import java.util.concurrent.{ConcurrentHashMap, LinkedBlockingQueue, TimeUnit} +import java.util.zip.CheckedInputStream + +import scala.collection.mutable +import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Queue} +import scala.util.{Failure, Success} + +/** + * An iterator that fetches multiple blocks. For local blocks, it fetches from the local block + * manager. For remote blocks, it fetches them using the provided BlockTransferService. + * + * This creates an iterator of (BlockID, InputStream) tuples so the caller can handle blocks in a + * pipelined fashion as they are received. + * + * The implementation throttles the remote fetches so they don't exceed maxBytesInFlight to avoid + * using too much memory. 
+ * + * @param context + * [[TaskContext]], used for metrics update + * @param shuffleClient + * [[BlockStoreClient]] for fetching remote blocks + * @param blockManager + * [[BlockManager]] for reading local blocks + * @param blocksByAddress + * list of blocks to fetch grouped by the [[BlockManagerId]]. For each block we also require two + * info: 1. the size (in bytes as a long field) in order to throttle the memory usage; 2. the + * mapIndex for this block, which indicate the index in the map stage. Note that zero-sized blocks + * are already excluded, which happened in + * [[org.apache.spark.MapOutputTracker.convertMapStatuses]]. + * @param mapOutputTracker + * [[MapOutputTracker]] for falling back to fetching the original blocks if we fail to fetch + * shuffle chunks when push based shuffle is enabled. + * @param streamWrapper + * A function to wrap the returned input stream. + * @param maxBytesInFlight + * max size (in bytes) of remote blocks to fetch at any given point. + * @param maxReqsInFlight + * max number of remote requests to fetch blocks at any given point. + * @param maxBlocksInFlightPerAddress + * max number of shuffle blocks being fetched at any given point for a given remote host:port. + * @param maxReqSizeShuffleToMem + * max size (in bytes) of a request that can be shuffled to memory. + * @param maxAttemptsOnNettyOOM + * The max number of a block could retry due to Netty OOM before throwing the fetch failure. + * @param detectCorrupt + * whether to detect any corruption in fetched blocks. + * @param checksumEnabled + * whether the shuffle checksum is enabled. When enabled, Spark will try to diagnose the cause of + * the block corruption. + * @param checksumAlgorithm + * the checksum algorithm that is used when calculating the checksum value for the block data. + * @param shuffleMetrics + * used to report shuffle metrics. + * @param doBatchFetch + * fetch continuous shuffle blocks from same executor in batch if the server side supports. 
+ */ +final class GlutenShuffleBlockFetcherIterator( + context: TaskContext, + shuffleClient: BlockStoreClient, + blockManager: BlockManager, + mapOutputTracker: MapOutputTracker, + blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])], + streamWrapper: (BlockId, InputStream) => InputStream, + maxBytesInFlight: Long, + maxReqsInFlight: Int, + maxBlocksInFlightPerAddress: Int, + val maxReqSizeShuffleToMem: Long, + maxAttemptsOnNettyOOM: Int, + detectCorrupt: Boolean, + detectCorruptUseExtraMemory: Boolean, + checksumEnabled: Boolean, + checksumAlgorithm: String, + shuffleMetrics: ShuffleReadMetricsReporter, + doBatchFetch: Boolean) + extends GlutenShuffleBlockFetcherIteratorBase + with DownloadFileManager + with Logging { + + import ShuffleBlockFetcherIterator._ + + // Make remote requests at most maxBytesInFlight / 5 in length; the reason to keep them + // smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5 + // nodes, rather than blocking on reading output from one node. + private val targetRemoteRequestSize = math.max(maxBytesInFlight / 5, 1L) + + /** + * Total number of blocks to fetch. + */ + private[this] var numBlocksToFetch = 0 + + /** + * The number of blocks processed by the caller. The iterator is exhausted when + * [[numBlocksProcessed]] == [[numBlocksToFetch]]. + */ + private[this] var numBlocksProcessed = 0 + + private[this] val startTimeNs = System.nanoTime() + + /** Host local blocks to fetch, excluding zero-sized blocks. */ + private[this] val hostLocalBlocks = scala.collection.mutable.LinkedHashSet[(BlockId, Int)]() + + /** + * A queue to hold our results. This turns the asynchronous model provided by + * [[org.apache.spark.network.BlockTransferService]] into a synchronous model (iterator). + */ + private[this] val results = new LinkedBlockingQueue[FetchResult] + + /** + * Current [[FetchResult]] being processed per thread. 
We track this so we can release the current + * buffer in case of a runtime exception when processing the current buffer. Using + * ConcurrentHashMap to support concurrent access from multiple threads while allowing cleanup + * from any thread. + */ + private[this] val currentResults: ConcurrentHashMap[Long, SuccessFetchResult] = + new ConcurrentHashMap[Long, SuccessFetchResult]() + + /** + * Queue of fetch requests to issue; we'll pull requests off this gradually to make sure that the + * number of bytes in flight is limited to maxBytesInFlight. + */ + private[this] val fetchRequests = new Queue[FetchRequest] + + /** + * Queue of fetch requests which could not be issued the first time they were dequeued. These + * requests are tried again when the fetch constraints are satisfied. + */ + private[this] val deferredFetchRequests = new HashMap[BlockManagerId, Queue[FetchRequest]]() + + /** Current bytes in flight from our requests */ + private[this] var bytesInFlight = 0L + + /** Current number of requests in flight */ + private[this] var reqsInFlight = 0 + + /** Current number of blocks in flight per host:port */ + private[this] val numBlocksInFlightPerAddress = new HashMap[BlockManagerId, Int]() + + /** + * Count the retry times for the blocks due to Netty OOM. The block will stop retry if retry times + * has exceeded the [[maxAttemptsOnNettyOOM]]. + */ + private[this] val blockOOMRetryCounts = new HashMap[String, Int] + + /** + * The blocks that can't be decompressed successfully, it is used to guarantee that we retry at + * most once for those corrupted blocks. + */ + private[this] val corruptedBlocks = mutable.HashSet[BlockId]() + + /** + * Whether the iterator is still active. If isZombie is true, the callback interface will no + * longer place fetched blocks into [[results]]. + */ + @GuardedBy("this") + private[this] var isZombie = false + + /** + * A set to store the files used for shuffling remote huge blocks. 
Files in this set will be + * deleted when cleanup. This is a layer of defensiveness against disk file leaks. + */ + @GuardedBy("this") + private[this] val shuffleFilesSet = mutable.HashSet[DownloadFile]() + + private[this] val onCompleteCallback = new GlutenShuffleFetchCompletionListener(this) + + private[this] val pushBasedFetchHelper = new GlutenPushBasedFetchHelper( + this, + shuffleClient, + blockManager, + mapOutputTracker) + + initialize() + + // Decrements the buffer reference count. + // The currentResult is removed from the map to prevent releasing the buffer again on cleanup() + private[storage] def releaseCurrentResultBuffer(): Unit = { + val threadId = Thread.currentThread().getId + // Release the current buffer if necessary + val result = currentResults.remove(threadId) + if (result != null) { + result.buf.release() + } + } + + override def createTempFile(transportConf: TransportConf): DownloadFile = { + // we never need to do any encryption or decryption here, regardless of configs, because that + // is handled at another layer in the code. When encryption is enabled, shuffle data is written + // to disk encrypted in the first place, and sent over the network still encrypted. + new SimpleDownloadFile( + blockManager.diskBlockManager.createTempLocalBlock()._2, + transportConf) + } + + override def registerTempFileToClean(file: DownloadFile): Boolean = synchronized { + if (isZombie) { + false + } else { + shuffleFilesSet += file + true + } + } + + /** + * Mark the iterator as zombie, and release all buffers that haven't been deserialized yet. 
+ */ + private[storage] def cleanup(): Unit = { + synchronized { + isZombie = true + } + releaseCurrentResultBuffer() + // Release buffers in the results queue + val iter = results.iterator() + while (iter.hasNext) { + val result = iter.next() + result match { + case SuccessFetchResult(blockId, mapIndex, address, _, buf, _) => + if (address != blockManager.blockManagerId) { + if (hostLocalBlocks.contains(blockId -> mapIndex)) { + shuffleMetrics.incLocalBlocksFetched(1) + shuffleMetrics.incLocalBytesRead(buf.size) + } else { + shuffleMetrics.incRemoteBytesRead(buf.size) + if (buf.isInstanceOf[FileSegmentManagedBuffer]) { + shuffleMetrics.incRemoteBytesReadToDisk(buf.size) + } + shuffleMetrics.incRemoteBlocksFetched(1) + } + } + buf.release() + case _ => + } + } + shuffleFilesSet.foreach { + file => + if (!file.delete()) { + logWarning("Failed to cleanup shuffle fetch temp file " + file.path()) + } + } + } + + private[this] def sendRequest(req: FetchRequest): Unit = { + logDebug("Sending request for %d blocks (%s) from %s".format( + req.blocks.size, + Utils.bytesToString(req.size), + req.address.hostPort)) + bytesInFlight += req.size + reqsInFlight += 1 + + // so we can look up the block info of each blockID + val infoMap = req.blocks.map { + case FetchBlockInfo(blockId, size, mapIndex) => (blockId.toString, (size, mapIndex)) + }.toMap + val remainingBlocks = new HashSet[String]() ++= infoMap.keys + val deferredBlocks = new ArrayBuffer[String]() + val blockIds = req.blocks.map(_.blockId.toString) + val address = req.address + + @inline def enqueueDeferredFetchRequestIfNecessary(): Unit = { + if (remainingBlocks.isEmpty && deferredBlocks.nonEmpty) { + val blocks = deferredBlocks.map { + blockId => + val (size, mapIndex) = infoMap(blockId) + FetchBlockInfo(BlockId(blockId), size, mapIndex) + } + results.put(DeferFetchRequestResult(FetchRequest(address, blocks.toSeq))) + deferredBlocks.clear() + } + } + + val blockFetchingListener = new BlockFetchingListener { + 
override def onBlockFetchSuccess(blockId: String, buf: ManagedBuffer): Unit = { + // Only add the buffer to results queue if the iterator is not zombie, + // i.e. cleanup() has not been called yet. + GlutenShuffleBlockFetcherIterator.this.synchronized { + if (!isZombie) { + // Increment the ref count because we need to pass this to a different thread. + // This needs to be released after use. + buf.retain() + remainingBlocks -= blockId + blockOOMRetryCounts.remove(blockId) + results.put(new SuccessFetchResult( + BlockId(blockId), + infoMap(blockId)._2, + address, + infoMap(blockId)._1, + buf, + remainingBlocks.isEmpty)) + logDebug("remainingBlocks: " + remainingBlocks) + enqueueDeferredFetchRequestIfNecessary() + } + } + logTrace(s"Got remote block $blockId after ${Utils.getUsedTimeNs(startTimeNs)}") + } + + override def onBlockFetchFailure(blockId: String, e: Throwable): Unit = { + GlutenShuffleBlockFetcherIterator.this.synchronized { + logError(s"Failed to get block(s) from ${req.address.host}:${req.address.port}", e) + e match { + // SPARK-27991: Catch the Netty OOM and set the flag `isNettyOOMOnShuffle` (shared among + // tasks) to true as early as possible. The pending fetch requests won't be sent + // afterwards until the flag is set to false on: + // 1) the Netty free memory >= maxReqSizeShuffleToMem + // - we'll check this whenever there's a fetch request succeeds. + // 2) the number of in-flight requests becomes 0 + // - we'll check this in `fetchUpToMaxBytes` whenever it's invoked. + // Although Netty memory is shared across multiple modules, e.g., shuffle, rpc, the flag + // only takes effect for the shuffle due to the implementation simplicity concern. + // And we'll buffer the consecutive block failures caused by the OOM error until there's + // no remaining blocks in the current request. Then, we'll package these blocks into + // a same fetch request for the retry later. 
In this way, instead of creating the fetch + // request per block, it would help reduce the concurrent connections and data loads + // pressure at remote server. + // Note that catching OOM and do something based on it is only a workaround for + // handling the Netty OOM issue, which is not the best way towards memory management. + // We can get rid of it when we find a way to manage Netty's memory precisely. + case _: OutOfDirectMemoryError + if blockOOMRetryCounts.getOrElseUpdate(blockId, 0) < maxAttemptsOnNettyOOM => + if (!isZombie) { + val failureTimes = blockOOMRetryCounts(blockId) + blockOOMRetryCounts(blockId) += 1 + if (isNettyOOMOnShuffle.compareAndSet(false, true)) { + // The fetcher can fail remaining blocks in batch for the same error. So we only + // log the warning once to avoid flooding the logs. + logInfo(s"Block $blockId has failed $failureTimes times " + + s"due to Netty OOM, will retry") + } + remainingBlocks -= blockId + deferredBlocks += blockId + enqueueDeferredFetchRequestIfNecessary() + } + + case _ => + val block = BlockId(blockId) + if (block.isShuffleChunk) { + remainingBlocks -= blockId + results.put(FallbackOnPushMergedFailureResult( + block, + address, + infoMap(blockId)._1, + remainingBlocks.isEmpty)) + } else { + results.put(FailureFetchResult(block, infoMap(blockId)._2, address, e)) + } + } + } + } + } + + // Fetch remote shuffle blocks to disk when the request is too large. Since the shuffle data is + // already encrypted and compressed over the wire(w.r.t. the related configs), we can just fetch + // the data and write it to file directly. 
+ if (req.size > maxReqSizeShuffleToMem) { + shuffleClient.fetchBlocks( + address.host, + address.port, + address.executorId, + blockIds.toArray, + blockFetchingListener, + this) + } else { + shuffleClient.fetchBlocks( + address.host, + address.port, + address.executorId, + blockIds.toArray, + blockFetchingListener, + null) + } + } + + /** + * This is called from initialize and also from the fallback which is triggered from + * [[PushBasedFetchHelper]]. + */ + private[this] def partitionBlocksByFetchMode( + blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])], + localBlocks: mutable.LinkedHashSet[(BlockId, Int)], + hostLocalBlocksByExecutor: mutable.LinkedHashMap[BlockManagerId, Seq[(BlockId, Long, Int)]], + pushMergedLocalBlocks: mutable.LinkedHashSet[BlockId]): ArrayBuffer[FetchRequest] = { + logDebug(s"maxBytesInFlight: $maxBytesInFlight, targetRemoteRequestSize: " + + s"$targetRemoteRequestSize, maxBlocksInFlightPerAddress: $maxBlocksInFlightPerAddress") + + // Partition to local, host-local, push-merged-local, remote (includes push-merged-remote) + // blocks.Remote blocks are further split into FetchRequests of size at most maxBytesInFlight + // in order to limit the amount of data in flight + val collectedRemoteRequests = new ArrayBuffer[FetchRequest] + var localBlockBytes = 0L + var hostLocalBlockBytes = 0L + var numHostLocalBlocks = 0 + var pushMergedLocalBlockBytes = 0L + val prevNumBlocksToFetch = numBlocksToFetch + + val fallback = FallbackStorage.FALLBACK_BLOCK_MANAGER_ID.executorId + val localExecIds = Set(blockManager.blockManagerId.executorId, fallback) + for ((address, blockInfos) <- blocksByAddress) { + checkBlockSizes(blockInfos) + if (pushBasedFetchHelper.isPushMergedShuffleBlockAddress(address)) { + // These are push-merged blocks or shuffle chunks of these blocks. 
+ if (address.host == blockManager.blockManagerId.host) { + numBlocksToFetch += blockInfos.size + pushMergedLocalBlocks ++= blockInfos.map(_._1) + pushMergedLocalBlockBytes += blockInfos.map(_._2).sum + } else { + collectFetchRequests(address, blockInfos, collectedRemoteRequests) + } + } else if (localExecIds.contains(address.executorId)) { + val mergedBlockInfos = mergeContinuousShuffleBlockIdsIfNeeded( + blockInfos.map(info => FetchBlockInfo(info._1, info._2, info._3)), + doBatchFetch) + numBlocksToFetch += mergedBlockInfos.size + localBlocks ++= mergedBlockInfos.map(info => (info.blockId, info.mapIndex)) + localBlockBytes += mergedBlockInfos.map(_.size).sum + } else if ( + blockManager.hostLocalDirManager.isDefined && + address.host == blockManager.blockManagerId.host + ) { + val mergedBlockInfos = mergeContinuousShuffleBlockIdsIfNeeded( + blockInfos.map(info => FetchBlockInfo(info._1, info._2, info._3)), + doBatchFetch) + numBlocksToFetch += mergedBlockInfos.size + val blocksForAddress = + mergedBlockInfos.map(info => (info.blockId, info.size, info.mapIndex)) + hostLocalBlocksByExecutor += address -> blocksForAddress + numHostLocalBlocks += blocksForAddress.size + hostLocalBlockBytes += mergedBlockInfos.map(_.size).sum + } else { + val (_, timeCost) = Utils.timeTakenMs[Unit] { + collectFetchRequests(address, blockInfos, collectedRemoteRequests) + } + logDebug(s"Collected remote fetch requests for $address in $timeCost ms") + } + } + val (remoteBlockBytes, numRemoteBlocks) = + collectedRemoteRequests.foldLeft((0L, 0))((x, y) => (x._1 + y.size, x._2 + y.blocks.size)) + val totalBytes = localBlockBytes + remoteBlockBytes + hostLocalBlockBytes + + pushMergedLocalBlockBytes + val blocksToFetchCurrentIteration = numBlocksToFetch - prevNumBlocksToFetch + assert( + blocksToFetchCurrentIteration == localBlocks.size + + numHostLocalBlocks + numRemoteBlocks + pushMergedLocalBlocks.size, + s"The number of non-empty blocks $blocksToFetchCurrentIteration doesn't equal to the 
sum " + + s"of the number of local blocks ${localBlocks.size} + " + + s"the number of host-local blocks $numHostLocalBlocks " + + s"the number of push-merged-local blocks ${pushMergedLocalBlocks.size} " + + s"+ the number of remote blocks $numRemoteBlocks " + ) + logInfo(s"Getting $blocksToFetchCurrentIteration " + + s"(${Utils.bytesToString(totalBytes)}) non-empty blocks including " + + s"${localBlocks.size} (${Utils.bytesToString(localBlockBytes)}) local and " + + s"$numHostLocalBlocks (${Utils.bytesToString(hostLocalBlockBytes)}) " + + s"host-local and ${pushMergedLocalBlocks.size} " + + s"(${Utils.bytesToString(pushMergedLocalBlockBytes)}) " + + s"push-merged-local and $numRemoteBlocks (${Utils.bytesToString(remoteBlockBytes)}) " + + s"remote blocks") + this.hostLocalBlocks ++= hostLocalBlocksByExecutor.values + .flatMap(infos => infos.map(info => (info._1, info._3))) + collectedRemoteRequests + } + + private def createFetchRequest( + blocks: Seq[FetchBlockInfo], + address: BlockManagerId, + forMergedMetas: Boolean): FetchRequest = { + logDebug(s"Creating fetch request of ${blocks.map(_.size).sum} at $address " + + s"with ${blocks.size} blocks") + FetchRequest(address, blocks, forMergedMetas) + } + + private def createFetchRequests( + curBlocks: Seq[FetchBlockInfo], + address: BlockManagerId, + isLast: Boolean, + collectedRemoteRequests: ArrayBuffer[FetchRequest], + enableBatchFetch: Boolean, + forMergedMetas: Boolean = false): ArrayBuffer[FetchBlockInfo] = { + val mergedBlocks = mergeContinuousShuffleBlockIdsIfNeeded(curBlocks, enableBatchFetch) + numBlocksToFetch += mergedBlocks.size + val retBlocks = new ArrayBuffer[FetchBlockInfo] + if (mergedBlocks.length <= maxBlocksInFlightPerAddress) { + collectedRemoteRequests += createFetchRequest(mergedBlocks, address, forMergedMetas) + } else { + mergedBlocks.grouped(maxBlocksInFlightPerAddress).foreach { + blocks => + if (blocks.length == maxBlocksInFlightPerAddress || isLast) { + collectedRemoteRequests += 
createFetchRequest(blocks, address, forMergedMetas) + } else { + // The last group does not exceed `maxBlocksInFlightPerAddress`. Put it back + // to `curBlocks`. + retBlocks ++= blocks + numBlocksToFetch -= blocks.size + } + } + } + retBlocks + } + + private def collectFetchRequests( + address: BlockManagerId, + blockInfos: Seq[(BlockId, Long, Int)], + collectedRemoteRequests: ArrayBuffer[FetchRequest]): Unit = { + val iterator = blockInfos.iterator + var curRequestSize = 0L + var curBlocks = new ArrayBuffer[FetchBlockInfo]() + + while (iterator.hasNext) { + val (blockId, size, mapIndex) = iterator.next() + curBlocks += FetchBlockInfo(blockId, size, mapIndex) + curRequestSize += size + blockId match { + // Either all blocks are push-merged blocks, shuffle chunks, or original blocks. + // Based on these types, we decide to do batch fetch and create FetchRequests with + // forMergedMetas set. + case ShuffleBlockChunkId(_, _, _, _) => + if ( + curRequestSize >= targetRemoteRequestSize || + curBlocks.size >= maxBlocksInFlightPerAddress + ) { + curBlocks = createFetchRequests( + curBlocks.toSeq, + address, + isLast = false, + collectedRemoteRequests, + enableBatchFetch = false) + curRequestSize = curBlocks.map(_.size).sum + } + case ShuffleMergedBlockId(_, _, _) => + if (curBlocks.size >= maxBlocksInFlightPerAddress) { + curBlocks = createFetchRequests( + curBlocks.toSeq, + address, + isLast = false, + collectedRemoteRequests, + enableBatchFetch = false, + forMergedMetas = true) + } + case _ => + // For batch fetch, the actual block in flight should count for merged block. 
+ val mayExceedsMaxBlocks = !doBatchFetch && curBlocks.size >= maxBlocksInFlightPerAddress + if (curRequestSize >= targetRemoteRequestSize || mayExceedsMaxBlocks) { + curBlocks = createFetchRequests( + curBlocks.toSeq, + address, + isLast = false, + collectedRemoteRequests, + doBatchFetch) + curRequestSize = curBlocks.map(_.size).sum + } + } + } + // Add in the final request + if (curBlocks.nonEmpty) { + val (enableBatchFetch, forMergedMetas) = { + curBlocks.head.blockId match { + case ShuffleBlockChunkId(_, _, _, _) => (false, false) + case ShuffleMergedBlockId(_, _, _) => (false, true) + case _ => (doBatchFetch, false) + } + } + createFetchRequests( + curBlocks.toSeq, + address, + isLast = true, + collectedRemoteRequests, + enableBatchFetch = enableBatchFetch, + forMergedMetas = forMergedMetas) + } + } + + private def assertPositiveBlockSize(blockId: BlockId, blockSize: Long): Unit = { + if (blockSize < 0) { + throw BlockException(blockId, "Negative block size " + size) + } else if (blockSize == 0) { + throw BlockException(blockId, "Zero-sized blocks should be excluded.") + } + } + + private def checkBlockSizes(blockInfos: Seq[(BlockId, Long, Int)]): Unit = { + blockInfos.foreach { case (blockId, size, _) => assertPositiveBlockSize(blockId, size) } + } + + /** + * Fetch the local blocks while we are fetching remote blocks. This is ok because + * `ManagedBuffer`'s memory is allocated lazily when we create the input stream, so all we track + * in-memory are the ManagedBuffer references themselves. 
+ */ + private[this] def fetchLocalBlocks( + localBlocks: mutable.LinkedHashSet[(BlockId, Int)]): Unit = { + logDebug(s"Start fetching local blocks: ${localBlocks.mkString(", ")}") + val iter = localBlocks.iterator + while (iter.hasNext) { + val (blockId, mapIndex) = iter.next() + try { + val buf = blockManager.getLocalBlockData(blockId) + shuffleMetrics.incLocalBlocksFetched(1) + shuffleMetrics.incLocalBytesRead(buf.size) + buf.retain() + results.put(new SuccessFetchResult( + blockId, + mapIndex, + blockManager.blockManagerId, + buf.size(), + buf, + false)) + } catch { + // If we see an exception, stop immediately. + case e: Exception => + e match { + // ClosedByInterruptException is an excepted exception when kill task, + // don't log the exception stack trace to avoid confusing users. + // See: SPARK-28340 + case ce: ClosedByInterruptException => + logError("Error occurred while fetching local blocks, " + ce.getMessage) + case ex: Exception => logError("Error occurred while fetching local blocks", ex) + } + results.put(new FailureFetchResult(blockId, mapIndex, blockManager.blockManagerId, e)) + return + } + } + } + + private[this] def fetchHostLocalBlock( + blockId: BlockId, + mapIndex: Int, + localDirs: Array[String], + blockManagerId: BlockManagerId): Boolean = { + try { + val buf = blockManager.getHostLocalShuffleData(blockId, localDirs) + buf.retain() + results.put(SuccessFetchResult( + blockId, + mapIndex, + blockManagerId, + buf.size(), + buf, + isNetworkReqDone = false)) + true + } catch { + case e: Exception => + // If we see an exception, stop immediately. + logError(s"Error occurred while fetching local blocks", e) + results.put(FailureFetchResult(blockId, mapIndex, blockManagerId, e)) + false + } + } + + /** + * Fetch the host-local blocks while we are fetching remote blocks. This is ok because + * `ManagedBuffer`'s memory is allocated lazily when we create the input stream, so all we track + * in-memory are the ManagedBuffer references themselves. 
+ */ + private[this] def fetchHostLocalBlocks( + hostLocalDirManager: HostLocalDirManager, + hostLocalBlocksByExecutor: mutable.LinkedHashMap[BlockManagerId, Seq[(BlockId, Long, Int)]]) + : Unit = { + val cachedDirsByExec = hostLocalDirManager.getCachedHostLocalDirs + val (hostLocalBlocksWithCachedDirs, hostLocalBlocksWithMissingDirs) = { + val (hasCache, noCache) = hostLocalBlocksByExecutor.partition { + case (hostLocalBmId, _) => + cachedDirsByExec.contains(hostLocalBmId.executorId) + } + (hasCache.toMap, noCache.toMap) + } + + if (hostLocalBlocksWithMissingDirs.nonEmpty) { + logDebug(s"Asynchronous fetching host-local blocks without cached executors' dir: " + + s"${hostLocalBlocksWithMissingDirs.mkString(", ")}") + + // If the external shuffle service is enabled, we'll fetch the local directories for + // multiple executors from the external shuffle service, which located at the same host + // with the executors, in once. Otherwise, we'll fetch the local directories from those + // executors directly one by one. The fetch requests won't be too much since one host is + // almost impossible to have many executors at the same time practically. 
+ val dirFetchRequests = if (blockManager.externalShuffleServiceEnabled) { + val host = blockManager.blockManagerId.host + val port = blockManager.externalShuffleServicePort + Seq((host, port, hostLocalBlocksWithMissingDirs.keys.toArray)) + } else { + hostLocalBlocksWithMissingDirs.keys.map(bmId => (bmId.host, bmId.port, Array(bmId))).toSeq + } + + dirFetchRequests.foreach { + case (host, port, bmIds) => + hostLocalDirManager.getHostLocalDirs(host, port, bmIds.map(_.executorId)) { + case Success(dirsByExecId) => + fetchMultipleHostLocalBlocks( + hostLocalBlocksWithMissingDirs.filterKeys(bmIds.contains).toMap, + dirsByExecId, + cached = false) + + case Failure(throwable) => + logError("Error occurred while fetching host local blocks", throwable) + val bmId = bmIds.head + val blockInfoSeq = hostLocalBlocksWithMissingDirs(bmId) + val (blockId, _, mapIndex) = blockInfoSeq.head + results.put(FailureFetchResult(blockId, mapIndex, bmId, throwable)) + } + } + } + + if (hostLocalBlocksWithCachedDirs.nonEmpty) { + logDebug(s"Synchronous fetching host-local blocks with cached executors' dir: " + + s"${hostLocalBlocksWithCachedDirs.mkString(", ")}") + fetchMultipleHostLocalBlocks(hostLocalBlocksWithCachedDirs, cachedDirsByExec, cached = true) + } + } + + private def fetchMultipleHostLocalBlocks( + bmIdToBlocks: Map[BlockManagerId, Seq[(BlockId, Long, Int)]], + localDirsByExecId: Map[String, Array[String]], + cached: Boolean): Unit = { + // We use `forall` because once there's a failed block fetch, `fetchHostLocalBlock` will put + // a `FailureFetchResult` immediately to the `results`. So there's no reason to fetch the + // remaining blocks. 
+ val allFetchSucceeded = bmIdToBlocks.forall { + case (bmId, blockInfos) => + blockInfos.forall { + case (blockId, _, mapIndex) => + fetchHostLocalBlock(blockId, mapIndex, localDirsByExecId(bmId.executorId), bmId) + } + } + if (allFetchSucceeded) { + logDebug(s"Got host-local blocks from ${bmIdToBlocks.keys.mkString(", ")} " + + s"(${if (cached) "with" else "without"} cached executors' dir) " + + s"in ${Utils.getUsedTimeNs(startTimeNs)}") + } + } + + private[this] def initialize(): Unit = { + // Add a task completion callback (called in both success case and failure case) to cleanup. + context.addTaskCompletionListener(onCompleteCallback) + // Local blocks to fetch, excluding zero-sized blocks. + val localBlocks = mutable.LinkedHashSet[(BlockId, Int)]() + val hostLocalBlocksByExecutor = + mutable.LinkedHashMap[BlockManagerId, Seq[(BlockId, Long, Int)]]() + val pushMergedLocalBlocks = mutable.LinkedHashSet[BlockId]() + // Partition blocks by the different fetch modes: local, host-local, push-merged-local and + // remote blocks. 
+ val remoteRequests = partitionBlocksByFetchMode( + blocksByAddress, + localBlocks, + hostLocalBlocksByExecutor, + pushMergedLocalBlocks) + // Add the remote requests into our queue in a random order + fetchRequests ++= Utils.randomize(remoteRequests) + assert( + (0 == reqsInFlight) == (0 == bytesInFlight), + "expected reqsInFlight = 0 but found reqsInFlight = " + reqsInFlight + + ", expected bytesInFlight = 0 but found bytesInFlight = " + bytesInFlight + ) + + // Send out initial requests for blocks, up to our maxBytesInFlight + fetchUpToMaxBytes() + + val numDeferredRequest = deferredFetchRequests.values.map(_.size).sum + val numFetches = remoteRequests.size - fetchRequests.size - numDeferredRequest + logInfo(s"Started $numFetches remote fetches in ${Utils.getUsedTimeNs(startTimeNs)}" + + (if (numDeferredRequest > 0) s", deferred $numDeferredRequest requests" else "")) + + // Get Local Blocks + fetchLocalBlocks(localBlocks) + logDebug(s"Got local blocks in ${Utils.getUsedTimeNs(startTimeNs)}") + // Get host local blocks if any + fetchAllHostLocalBlocks(hostLocalBlocksByExecutor) + pushBasedFetchHelper.fetchAllPushMergedLocalBlocks(pushMergedLocalBlocks) + } + + private def fetchAllHostLocalBlocks( + hostLocalBlocksByExecutor: mutable.LinkedHashMap[BlockManagerId, Seq[(BlockId, Long, Int)]]) + : Unit = { + if (hostLocalBlocksByExecutor.nonEmpty) { + blockManager.hostLocalDirManager.foreach(fetchHostLocalBlocks(_, hostLocalBlocksByExecutor)) + } + } + + override def hasNext: Boolean = numBlocksProcessed < numBlocksToFetch + + /** + * Fetches the next (BlockId, InputStream). If a task fails, the ManagedBuffers underlying each + * InputStream will be freed by the cleanup() method registered with the TaskCompletionListener. + * However, callers should close() these InputStreams as soon as they are no longer needed, in + * order to release memory as early as possible. + * + * Throws a FetchFailedException if the next block could not be fetched. 
+   */
+  override def next(): (BlockId, InputStream) = {
+    if (!hasNext) {
+      throw SparkCoreErrors.noSuchElementError()
+    }
+
+    numBlocksProcessed += 1
+
+    var result: FetchResult = null
+    var input: InputStream = null
+    // This's only initialized when shuffle checksum is enabled.
+    var checkedIn: CheckedInputStream = null
+    var streamCompressedOrEncrypted: Boolean = false
+    // Take the next fetched result and try to decompress it to detect data corruption,
+    // then fetch it one more time if it's corrupt, throw FailureFetchResult if the second fetch
+    // is also corrupt, so the previous stage could be retried.
+    // For local shuffle block, throw FailureFetchResult for the first IOException.
+    while (result == null) {
+      val startFetchWait = System.nanoTime()
+      result = results.take()
+      val fetchWaitTime = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startFetchWait)
+      shuffleMetrics.incFetchWaitTime(fetchWaitTime)
+
+      result match {
+        case r @ SuccessFetchResult(blockId, mapIndex, address, size, buf, isNetworkReqDone) =>
+          if (address != blockManager.blockManagerId) {
+            if (
+              hostLocalBlocks.contains(blockId -> mapIndex) ||
+              pushBasedFetchHelper.isLocalPushMergedBlockAddress(address)
+            ) {
+              // It is a host local block or a local shuffle chunk
+              shuffleMetrics.incLocalBlocksFetched(1)
+              shuffleMetrics.incLocalBytesRead(buf.size)
+            } else {
+              numBlocksInFlightPerAddress(address) = numBlocksInFlightPerAddress(address) - 1
+              shuffleMetrics.incRemoteBytesRead(buf.size)
+              if (buf.isInstanceOf[FileSegmentManagedBuffer]) {
+                shuffleMetrics.incRemoteBytesReadToDisk(buf.size)
+              }
+              shuffleMetrics.incRemoteBlocksFetched(1)
+              bytesInFlight -= size
+            }
+          }
+          if (isNetworkReqDone) {
+            reqsInFlight -= 1
+            resetNettyOOMFlagIfPossible(maxReqSizeShuffleToMem)
+            logDebug("Number of requests in flight " + reqsInFlight)
+          }
+
+          val in = if (buf.size == 0) {
+            // We will never legitimately receive a zero-size block. All blocks with zero records
+            // have zero size and all zero-size blocks have no records (and hence should never
+            // have been requested in the first place). This statement relies on behaviors of the
+            // shuffle writers, which are guaranteed by the following test cases:
+            //
+            // - BypassMergeSortShuffleWriterSuite: "write with some empty partitions"
+            // - UnsafeShuffleWriterSuite: "writeEmptyIterator"
+            // - DiskBlockObjectWriterSuite: "commit() and close() without ever opening or writing"
+            //
+            // There is not an explicit test for SortShuffleWriter but the underlying APIs that
+            // uses are shared by the UnsafeShuffleWriter (both writers use DiskBlockObjectWriter
+            // which returns a zero-size from commitAndGet() in case no records were written
+            // since the last call.
+            val msg = s"Received a zero-size buffer for block $blockId from $address " +
+              s"(expectedApproxSize = $size, isNetworkReqDone=$isNetworkReqDone)"
+            if (blockId.isShuffleChunk) {
+              // Zero-size block may come from nodes with hardware failures, For shuffle chunks,
+              // the original shuffle blocks that belong to that zero-size shuffle chunk is
+              // available and we can opt to fallback immediately.
+              logWarning(msg)
+              pushBasedFetchHelper.initiateFallbackFetchForPushMergedBlock(blockId, address)
+              // Set result to null to trigger another iteration of the while loop to get either.
+              result = null
+              null
+            } else {
+              throwFetchFailedException(blockId, mapIndex, address, new IOException(msg))
+            }
+          } else {
+            try {
+              val bufIn = buf.createInputStream()
+              if (checksumEnabled) {
+                val checksum = ShuffleChecksumHelper.getChecksumByAlgorithm(checksumAlgorithm)
+                checkedIn = new CheckedInputStream(bufIn, checksum)
+                checkedIn
+              } else {
+                bufIn
+              }
+            } catch {
+              // The exception could only be throwed by local shuffle block
+              case e: IOException =>
+                assert(buf.isInstanceOf[FileSegmentManagedBuffer])
+                e match {
+                  case ce: ClosedByInterruptException =>
+                    logError("Failed to create input stream from local block, " +
+                      ce.getMessage)
+                  case e: IOException =>
+                    logError("Failed to create input stream from local block", e)
+                }
+                buf.release()
+                if (blockId.isShuffleChunk) {
+                  pushBasedFetchHelper.initiateFallbackFetchForPushMergedBlock(blockId, address)
+                  // Set result to null to trigger another iteration of the while loop to get
+                  // either.
+                  result = null
+                  null
+                } else {
+                  throwFetchFailedException(blockId, mapIndex, address, e)
+                }
+            }
+          }
+
+          if (in != null) {
+            try {
+              input = streamWrapper(blockId, in)
+              // If the stream is compressed or wrapped, then we optionally decompress/unwrap the
+              // first maxBytesInFlight/3 bytes into memory, to check for corruption in that portion
+              // of the data. But even if 'detectCorruptUseExtraMemory' configuration is off, or if
+              // the corruption is later, we'll still detect the corruption later in the stream.
+              streamCompressedOrEncrypted = !input.eq(in)
+              if (streamCompressedOrEncrypted && detectCorruptUseExtraMemory) {
+                // TODO: manage the memory used here, and spill it into disk in case of OOM.
+                input = Utils.copyStreamUpTo(input, maxBytesInFlight / 3)
+              }
+            } catch {
+              case e: IOException =>
+                // When shuffle checksum is enabled, for a block that is corrupted twice,
+                // we'd calculate the checksum of the block by consuming the remaining data
+                // in the buf. So, we should release the buf later.
+                if (!(checksumEnabled && corruptedBlocks.contains(blockId))) {
+                  buf.release()
+                }
+
+                if (blockId.isShuffleChunk) {
+                  // TODO (SPARK-36284): Add shuffle checksum support for push-based shuffle
+                  // Retrying a corrupt block may result again in a corrupt block. For shuffle
+                  // chunks, we opt to fallback on the original shuffle blocks that belong to that
+                  // corrupt shuffle chunk immediately instead of retrying to fetch the corrupt
+                  // chunk. This also makes the code simpler because the chunkMeta corresponding to
+                  // a shuffle chunk is always removed from chunksMetaMap whenever a shuffle chunk
+                  // gets processed. If we try to re-fetch a corrupt shuffle chunk, then it has to
+                  // be added back to the chunksMetaMap.
+                  pushBasedFetchHelper.initiateFallbackFetchForPushMergedBlock(blockId, address)
+                  // Set result to null to trigger another iteration of the while loop.
+                  result = null
+                } else if (buf.isInstanceOf[FileSegmentManagedBuffer]) {
+                  throwFetchFailedException(blockId, mapIndex, address, e)
+                } else if (corruptedBlocks.contains(blockId)) {
+                  // It's the second time this block is detected corrupted
+                  if (checksumEnabled) {
+                    // Diagnose the cause of data corruption if shuffle checksum is enabled
+                    val diagnosisResponse = diagnoseCorruption(checkedIn, address, blockId)
+                    buf.release()
+                    logError(diagnosisResponse)
+                    throwFetchFailedException(
+                      blockId,
+                      mapIndex,
+                      address,
+                      e,
+                      Some(diagnosisResponse))
+                  } else {
+                    throwFetchFailedException(blockId, mapIndex, address, e)
+                  }
+                } else {
+                  // It's the first time this block is detected corrupted
+                  logWarning(s"got an corrupted block $blockId from $address, fetch again", e)
+                  corruptedBlocks += blockId
+                  fetchRequests += FetchRequest(
+                    address,
+                    Array(FetchBlockInfo(blockId, size, mapIndex)))
+                  result = null
+                }
+            } finally {
+              if (blockId.isShuffleChunk) {
+                pushBasedFetchHelper.removeChunk(blockId.asInstanceOf[ShuffleBlockChunkId])
+              }
+              // TODO: release the buf here to free memory earlier
+              if (input == null) {
+                // Close the underlying stream if there was an issue in wrapping the stream using
+                // streamWrapper
+                in.close()
+              }
+            }
+          }
+
+        case FailureFetchResult(blockId, mapIndex, address, e) =>
+          var errorMsg: String = null
+          if (e.isInstanceOf[OutOfDirectMemoryError]) {
+            errorMsg = s"Block $blockId fetch failed after $maxAttemptsOnNettyOOM " +
+              s"retries due to Netty OOM"
+            logError(errorMsg)
+          }
+          throwFetchFailedException(blockId, mapIndex, address, e, Some(errorMsg))
+
+        case DeferFetchRequestResult(request) =>
+          val address = request.address
+          numBlocksInFlightPerAddress(address) =
+            numBlocksInFlightPerAddress(address) - request.blocks.size
+          bytesInFlight -= request.size
+          reqsInFlight -= 1
+          logDebug("Number of requests in flight " + reqsInFlight)
+          val defReqQueue =
+            deferredFetchRequests.getOrElseUpdate(address, new Queue[FetchRequest]())
+          defReqQueue.enqueue(request)
+          result = null
+
+        case FallbackOnPushMergedFailureResult(blockId, address, size, isNetworkReqDone) =>
+          // We get this result in 3 cases:
+          // 1. Failure to fetch the data of a remote shuffle chunk. In this case, the
+          //    blockId is a ShuffleBlockChunkId.
+          // 2. Failure to read the push-merged-local meta. In this case, the blockId is
+          //    ShuffleBlockId.
+          // 3. Failure to get the push-merged-local directories from the external shuffle service.
+          //    In this case, the blockId is ShuffleBlockId.
+          if (pushBasedFetchHelper.isRemotePushMergedBlockAddress(address)) {
+            numBlocksInFlightPerAddress(address) = numBlocksInFlightPerAddress(address) - 1
+            bytesInFlight -= size
+          }
+          if (isNetworkReqDone) {
+            reqsInFlight -= 1
+            logDebug("Number of requests in flight " + reqsInFlight)
+          }
+          pushBasedFetchHelper.initiateFallbackFetchForPushMergedBlock(blockId, address)
+          // Set result to null to trigger another iteration of the while loop to get either
+          // a SuccessFetchResult or a FailureFetchResult.
+          result = null
+
+        case PushMergedLocalMetaFetchResult(
+              shuffleId,
+              shuffleMergeId,
+              reduceId,
+              bitmaps,
+              localDirs) =>
+          // Fetch push-merged-local shuffle block data as multiple shuffle chunks
+          val shuffleBlockId = ShuffleMergedBlockId(shuffleId, shuffleMergeId, reduceId)
+          try {
+            val bufs: Seq[ManagedBuffer] = blockManager.getLocalMergedBlockData(
+              shuffleBlockId,
+              localDirs)
+            // Since the request for local block meta completed successfully, numBlocksToFetch
+            // is decremented.
+            numBlocksToFetch -= 1
+            // Update total number of blocks to fetch, reflecting the multiple local shuffle
+            // chunks.
+            numBlocksToFetch += bufs.size
+            bufs.zipWithIndex.foreach {
+              case (buf, chunkId) =>
+                buf.retain()
+                val shuffleChunkId = ShuffleBlockChunkId(
+                  shuffleId,
+                  shuffleMergeId,
+                  reduceId,
+                  chunkId)
+                pushBasedFetchHelper.addChunk(shuffleChunkId, bitmaps(chunkId))
+                results.put(SuccessFetchResult(
+                  shuffleChunkId,
+                  SHUFFLE_PUSH_MAP_ID,
+                  pushBasedFetchHelper.localShuffleMergerBlockMgrId,
+                  buf.size(),
+                  buf,
+                  isNetworkReqDone = false))
+            }
+          } catch {
+            case e: Exception =>
+              // If we see an exception with reading push-merged-local index file, we fallback
+              // to fetch the original blocks. We do not report block fetch failure
+              // and will continue with the remaining local block read.
+              logWarning(
+                s"Error occurred while reading push-merged-local index, " +
+                  s"prepare to fetch the original blocks",
+                e)
+              pushBasedFetchHelper.initiateFallbackFetchForPushMergedBlock(
+                shuffleBlockId,
+                pushBasedFetchHelper.localShuffleMergerBlockMgrId)
+          }
+          result = null
+
+        case PushMergedRemoteMetaFetchResult(
+              shuffleId,
+              shuffleMergeId,
+              reduceId,
+              blockSize,
+              bitmaps,
+              address) =>
+          // The original meta request is processed so we decrease numBlocksToFetch and
+          // numBlocksInFlightPerAddress by 1. We will collect new shuffle chunks request and the
+          // count of this is added to numBlocksToFetch in collectFetchReqsFromMergedBlocks.
+          numBlocksInFlightPerAddress(address) = numBlocksInFlightPerAddress(address) - 1
+          numBlocksToFetch -= 1
+          val blocksToFetch = pushBasedFetchHelper.createChunkBlockInfosFromMetaResponse(
+            shuffleId,
+            shuffleMergeId,
+            reduceId,
+            blockSize,
+            bitmaps)
+          val additionalRemoteReqs = new ArrayBuffer[FetchRequest]
+          collectFetchRequests(address, blocksToFetch.toSeq, additionalRemoteReqs)
+          fetchRequests ++= additionalRemoteReqs
+          // Set result to null to force another iteration.
+          result = null
+
+        case PushMergedRemoteMetaFailedFetchResult(
+              shuffleId,
+              shuffleMergeId,
+              reduceId,
+              address) =>
+          // The original meta request failed so we decrease numBlocksInFlightPerAddress by 1.
+          numBlocksInFlightPerAddress(address) = numBlocksInFlightPerAddress(address) - 1
+          // If we fail to fetch the meta of a push-merged block, we fall back to fetching the
+          // original blocks.
+          pushBasedFetchHelper.initiateFallbackFetchForPushMergedBlock(
+            ShuffleMergedBlockId(shuffleId, shuffleMergeId, reduceId),
+            address)
+          // Set result to null to force another iteration.
+          result = null
+      }
+
+      // Send fetch requests up to maxBytesInFlight
+      fetchUpToMaxBytes()
+    }
+
+    val successResult = result.asInstanceOf[SuccessFetchResult]
+    val threadId = Thread.currentThread().getId
+    currentResults.put(threadId, successResult)
+    (
+      successResult.blockId,
+      new GlutenBufferReleasingInputStream(
+        input,
+        this,
+        successResult.blockId,
+        successResult.mapIndex,
+        successResult.address,
+        detectCorrupt && streamCompressedOrEncrypted,
+        successResult.isNetworkReqDone,
+        Option(checkedIn)
+      ))
+  }
+
+  /**
+   * Get the suspect corruption cause for the corrupted block. It should be only invoked when
+   * checksum is enabled and corruption was detected at least once.
+   *
+   * This will firstly consume the rest of stream of the corrupted block to calculate the checksum
+   * of the block. Then, it will raise a synchronized RPC call along with the checksum to ask the
+   * server(where the corrupted block is fetched from) to diagnose the cause of corruption and
+   * return it.
+   *
+   * Any exception raised during the process will result in the [[Cause.UNKNOWN_ISSUE]] of the
+   * corruption cause since corruption diagnosis is only a best effort.
+   *
+   * @param checkedIn
+   *   the [[CheckedInputStream]] which is used to calculate the checksum.
+   * @param address
+   *   the address where the corrupted block is fetched from.
+   * @param blockId
+   *   the blockId of the corrupted block.
+   * @return
+   *   The corruption diagnosis response for different causes.
+   */
+  private[storage] def diagnoseCorruption(
+      checkedIn: CheckedInputStream,
+      address: BlockManagerId,
+      blockId: BlockId): String = {
+    logInfo("Start corruption diagnosis.")
+    blockId match {
+      case shuffleBlock: ShuffleBlockId =>
+        val startTimeNs = System.nanoTime()
+        val buffer = new Array[Byte](ShuffleChecksumHelper.CHECKSUM_CALCULATION_BUFFER)
+        // consume the remaining data to calculate the checksum
+        var cause: Cause = null
+        try {
+          while (checkedIn.read(buffer) != -1) {}
+          val checksum = checkedIn.getChecksum.getValue
+          cause = shuffleClient.diagnoseCorruption(
+            address.host,
+            address.port,
+            address.executorId,
+            shuffleBlock.shuffleId,
+            shuffleBlock.mapId,
+            shuffleBlock.reduceId,
+            checksum,
+            checksumAlgorithm)
+        } catch {
+          case e: Exception =>
+            logWarning("Unable to diagnose the corruption cause of the corrupted block", e)
+            cause = Cause.UNKNOWN_ISSUE
+        }
+        val duration = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTimeNs)
+        val diagnosisResponse = cause match {
+          case Cause.UNSUPPORTED_CHECKSUM_ALGORITHM =>
+            s"Block $blockId is corrupted but corruption diagnosis failed due to " +
+              s"unsupported checksum algorithm: $checksumAlgorithm"
+
+          case Cause.CHECKSUM_VERIFY_PASS =>
+            s"Block $blockId is corrupted but checksum verification passed"
+
+          case Cause.UNKNOWN_ISSUE =>
+            s"Block $blockId is corrupted but the cause is unknown"
+
+          case otherCause =>
+            s"Block $blockId is corrupted due to $otherCause"
+        }
+        logInfo(s"Finished corruption diagnosis in $duration ms. $diagnosisResponse")
+        diagnosisResponse
+      case shuffleBlockChunk: ShuffleBlockChunkId =>
+        // TODO SPARK-36284 Add shuffle checksum support for push-based shuffle
+        val diagnosisResponse = s"BlockChunk $shuffleBlockChunk is corrupted but corruption " +
+          s"diagnosis is skipped due to lack of shuffle checksum support for push-based shuffle."
+        logWarning(diagnosisResponse)
+        diagnosisResponse
+      case unexpected: BlockId =>
+        throw new IllegalArgumentException(s"Unexpected type of BlockId, $unexpected")
+    }
+  }
+
+  override def onComplete(): Unit = {
+    onCompleteCallback.onComplete(context)
+  }
+
+  private def fetchUpToMaxBytes(): Unit = {
+    if (isNettyOOMOnShuffle.get()) {
+      if (reqsInFlight > 0) {
+        // Return immediately if Netty is still OOMed and there're ongoing fetch requests
+        return
+      } else {
+        resetNettyOOMFlagIfPossible(0)
+      }
+    }
+
+    // Send fetch requests up to maxBytesInFlight. If you cannot fetch from a remote host
+    // immediately, defer the request until the next time it can be processed.
+
+    // Process any outstanding deferred fetch requests if possible.
+    if (deferredFetchRequests.nonEmpty) {
+      for ((remoteAddress, defReqQueue) <- deferredFetchRequests) {
+        while (
+          isRemoteBlockFetchable(defReqQueue) &&
+          !isRemoteAddressMaxedOut(remoteAddress, defReqQueue.front)
+        ) {
+          val request = defReqQueue.dequeue()
+          logDebug(s"Processing deferred fetch request for $remoteAddress with " +
+            s"${request.blocks.length} blocks")
+          send(remoteAddress, request)
+          if (defReqQueue.isEmpty) {
+            deferredFetchRequests -= remoteAddress
+          }
+        }
+      }
+    }
+
+    // Process any regular fetch requests if possible.
+    while (isRemoteBlockFetchable(fetchRequests)) {
+      val request = fetchRequests.dequeue()
+      val remoteAddress = request.address
+      if (isRemoteAddressMaxedOut(remoteAddress, request)) {
+        logDebug(s"Deferring fetch request for $remoteAddress with ${request.blocks.size} blocks")
+        val defReqQueue = deferredFetchRequests.getOrElse(remoteAddress, new Queue[FetchRequest]())
+        defReqQueue.enqueue(request)
+        deferredFetchRequests(remoteAddress) = defReqQueue
+      } else {
+        send(remoteAddress, request)
+      }
+    }
+
+    def send(remoteAddress: BlockManagerId, request: FetchRequest): Unit = {
+      if (request.forMergedMetas) {
+        pushBasedFetchHelper.sendFetchMergedStatusRequest(request)
+      } else {
+        sendRequest(request)
+      }
+      numBlocksInFlightPerAddress(remoteAddress) =
+        numBlocksInFlightPerAddress.getOrElse(remoteAddress, 0) + request.blocks.size
+    }
+
+    def isRemoteBlockFetchable(fetchReqQueue: Queue[FetchRequest]): Boolean = {
+      fetchReqQueue.nonEmpty &&
+      (bytesInFlight == 0 ||
+        (reqsInFlight + 1 <= maxReqsInFlight &&
+          bytesInFlight + fetchReqQueue.front.size <= maxBytesInFlight))
+    }
+
+    // Checks if sending a new fetch request will exceed the max no. of blocks being fetched from a
+    // given remote address.
+    def isRemoteAddressMaxedOut(remoteAddress: BlockManagerId, request: FetchRequest): Boolean = {
+      numBlocksInFlightPerAddress.getOrElse(remoteAddress, 0) + request.blocks.size >
+        maxBlocksInFlightPerAddress
+    }
+  }
+
+  private[storage] def throwFetchFailedException(
+      blockId: BlockId,
+      mapIndex: Int,
+      address: BlockManagerId,
+      e: Throwable,
+      message: Option[String] = None) = {
+    val msg = message.getOrElse(e.getMessage)
+    blockId match {
+      case ShuffleBlockId(shufId, mapId, reduceId) =>
+        throw SparkCoreErrors.fetchFailedError(address, shufId, mapId, mapIndex, reduceId, msg, e)
+      case ShuffleBlockBatchId(shuffleId, mapId, startReduceId, _) =>
+        throw SparkCoreErrors.fetchFailedError(
+          address,
+          shuffleId,
+          mapId,
+          mapIndex,
+          startReduceId,
+          msg,
+          e)
+      case _ => throw SparkCoreErrors.failToGetNonShuffleBlockError(blockId, e)
+    }
+  }
+
+  /**
+   * All the below methods are used by [[PushBasedFetchHelper]] to communicate with the iterator
+   */
+  private[storage] def addToResultsQueue(result: FetchResult): Unit = {
+    results.put(result)
+  }
+
+  private[storage] def decreaseNumBlocksToFetch(blocksFetched: Int): Unit = {
+    numBlocksToFetch -= blocksFetched
+  }
+
+  /**
+   * Currently used by [[PushBasedFetchHelper]] to fetch fallback blocks when there is a fetch
+   * failure related to a push-merged block or shuffle chunk. This is executed by the task thread
+   * when the `iterator.next()` is invoked and if that initiates fallback.
+   */
+  private[storage] def fallbackFetch(
+      originalBlocksByAddr: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])]): Unit = {
+    val originalLocalBlocks = mutable.LinkedHashSet[(BlockId, Int)]()
+    val originalHostLocalBlocksByExecutor =
+      mutable.LinkedHashMap[BlockManagerId, Seq[(BlockId, Long, Int)]]()
+    val originalMergedLocalBlocks = mutable.LinkedHashSet[BlockId]()
+    val originalRemoteReqs = partitionBlocksByFetchMode(
+      originalBlocksByAddr,
+      originalLocalBlocks,
+      originalHostLocalBlocksByExecutor,
+      originalMergedLocalBlocks)
+    // Add the remote requests into our queue in a random order
+    fetchRequests ++= Utils.randomize(originalRemoteReqs)
+    logInfo(s"Created ${originalRemoteReqs.size} fallback remote requests for push-merged")
+    // fetch all the fallback blocks that are local.
+    fetchLocalBlocks(originalLocalBlocks)
+    // Merged local blocks should be empty during fallback
+    assert(
+      originalMergedLocalBlocks.isEmpty,
+      "There should be zero push-merged blocks during fallback")
+    // Some of the fallback local blocks could be host local blocks
+    fetchAllHostLocalBlocks(originalHostLocalBlocksByExecutor)
+  }
+
+  /**
+   * Removes all the pending shuffle chunks that are on the same host and have the same reduceId as
+   * the current chunk that had a fetch failure. This is executed by the task thread when the
+   * `iterator.next()` is invoked and if that initiates fallback.
+   *
+   * @return
+   *   set of all the removed shuffle chunk Ids.
+   */
+  private[storage] def removePendingChunks(
+      failedBlockId: ShuffleBlockChunkId,
+      address: BlockManagerId): mutable.HashSet[ShuffleBlockChunkId] = {
+    val removedChunkIds = new mutable.HashSet[ShuffleBlockChunkId]()
+
+    def sameShuffleReducePartition(block: BlockId): Boolean = {
+      val chunkId = block.asInstanceOf[ShuffleBlockChunkId]
+      chunkId.shuffleId == failedBlockId.shuffleId && chunkId.reduceId == failedBlockId.reduceId
+    }
+
+    def filterRequests(queue: mutable.Queue[FetchRequest]): Unit = {
+      val fetchRequestsToRemove = new mutable.Queue[FetchRequest]()
+      fetchRequestsToRemove ++= queue.dequeueAll {
+        req =>
+          val firstBlock = req.blocks.head
+          firstBlock.blockId.isShuffleChunk && req.address.equals(address) &&
+          sameShuffleReducePartition(firstBlock.blockId)
+      }
+      fetchRequestsToRemove.foreach {
+        _ =>
+          removedChunkIds ++=
+            fetchRequestsToRemove.flatMap(_.blocks.map(_.blockId.asInstanceOf[ShuffleBlockChunkId]))
+      }

Review Comment:
`removePendingChunks` builds `removedChunkIds` inside a `foreach`, but each iteration appends the full `fetchRequestsToRemove` list again. This is unnecessarily O(n²) work in a failure path and can inflate fallback latency.
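In Scala, the concern above would presumably be addressed by dropping the enclosing `fetchRequestsToRemove.foreach { _ => ... }` and keeping only the single `removedChunkIds ++= fetchRequestsToRemove.flatMap(...)` statement. Because `removedChunkIds` is a `HashSet`, the resulting set is the same either way; only the redundant work goes away. The sketch below illustrates that cost difference in C++ (the language used for all examples here). `FakeFetchRequest`, `collectQuadratic`, and `collectLinear` are hypothetical names for illustration, not Spark or Gluten code.

```cpp
#include <cassert>
#include <cstddef>
#include <unordered_set>
#include <vector>

// Hypothetical stand-in for a FetchRequest: only the chunk IDs it covers.
struct FakeFetchRequest {
  std::vector<int> chunkIds;
};

// Mirrors the reviewed pattern: for EACH removed request, re-add the chunk IDs of ALL
// removed requests. Returns how many insert calls were made.
std::size_t collectQuadratic(const std::vector<FakeFetchRequest>& removed, std::unordered_set<int>& out) {
  std::size_t inserts = 0;
  for (std::size_t i = 0; i < removed.size(); ++i) {
    for (const auto& req : removed) {
      for (int id : req.chunkIds) {
        out.insert(id);
        ++inserts;
      }
    }
  }
  return inserts;
}

// Suggested shape: a single pass, so each chunk ID is inserted exactly once.
std::size_t collectLinear(const std::vector<FakeFetchRequest>& removed, std::unordered_set<int>& out) {
  std::size_t inserts = 0;
  for (const auto& req : removed) {
    for (int id : req.chunkIds) {
      out.insert(id);
      ++inserts;
    }
  }
  return inserts;
}

// Usage: three requests with two chunk IDs each.
bool demoSameResultLessWork() {
  const std::vector<FakeFetchRequest> removed{
      FakeFetchRequest{{1, 2}}, FakeFetchRequest{{3, 4}}, FakeFetchRequest{{5, 6}}};
  std::unordered_set<int> viaQuadratic;
  std::unordered_set<int> viaLinear;
  const std::size_t slowInserts = collectQuadratic(removed, viaQuadratic);  // 3 * 3 * 2 = 18
  const std::size_t fastInserts = collectLinear(removed, viaLinear);        // 3 * 2 = 6
  assert(viaQuadratic == viaLinear);  // the set deduplicates, so only the work differs
  return slowInserts == 18 && fastInserts == 6 && viaLinear.size() == 6;
}
```

With n removed requests of k chunks each, the first shape performs n * n * k insertions and the second n * k, which is the gap the review comment points at.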
########## cpp/velox/shuffle/VeloxGpuShuffleReader.cc: ##########
@@ -62,62 +78,126 @@ VeloxGpuHashShuffleReaderDeserializer::VeloxGpuHashShuffleReaderDeserializer(
       rowType_(rowType),
       readerBufferSize_(readerBufferSize),
       memoryManager_(memoryManager),
+      threadPool_(threadPool),
       deserializeTime_(deserializeTime),
       decompressTime_(decompressTime) {}

-bool VeloxGpuHashShuffleReaderDeserializer::resolveNextBlockType() {
-  GLUTEN_ASSIGN_OR_THROW(auto blockType, readBlockType(in_.get()));
-  switch (blockType) {
-    case BlockType::kEndOfStream:
-      return false;
-    case BlockType::kPlainPayload:
-      return true;
-    default:
-      throw GlutenException(fmt::format("Unsupported block type: {}", static_cast<int32_t>(blockType)));
+VeloxGpuHashShuffleReaderDeserializer::~VeloxGpuHashShuffleReaderDeserializer() {
+  // Wait for all reader threads to complete before destroying
+  if (!isStopped()) {
+    stop();
   }
+
+  decompressTime_ += decompressTimeCounter_.load(std::memory_order_relaxed);
+  deserializeTime_ += deserializeTimeCounter_.load(std::memory_order_relaxed);
 }

-void VeloxGpuHashShuffleReaderDeserializer::loadNextStream() {
-  if (reachedEos_) {
-    return;
+std::unique_ptr<ColumnarBatchIterator> VeloxGpuHashShuffleReaderDeserializer::deserializeStreams(int32_t priority) {
+  batchQueue_ = std::make_unique<CachedBatchQueue<GpuBufferColumnarBatch>>(1L << 30);
+
+  if (!threadPool_) {
+    throw GlutenException("Thread pool must be provided to VeloxGpuHashShuffleReaderDeserializer");
+  }
+
+  const size_t numThreads = threadPool_->getNumThreads();
+  activeReaders_.store(numThreads);
+
+  // Submit reader tasks to the thread pool.
+  std::vector<ReaderThreadPool::Task> tasks;
+  tasks.reserve(numThreads);
+  for (size_t i = 0; i < numThreads; ++i) {
+    tasks.emplace_back([this]() { read(); });
   }
+  threadPool_->submitBatch(std::move(tasks), priority);

-  auto in = streamReader_->readNextStream(memoryManager_->defaultArrowMemoryPool());
-  if (in == nullptr) {
-    reachedEos_ = true;
-    return;
+  if (priority == 0) {
+    threadPool_->start();
   }

-  GLUTEN_ASSIGN_OR_THROW(
-      in_,
-      arrow::io::BufferedInputStream::Create(
-          readerBufferSize_, memoryManager_->defaultArrowMemoryPool(), std::move(in)));
+  return std::make_unique<AsyncShuffleReaderIterator<GpuBufferColumnarBatch>>(batchQueue_.get());
 }

-std::shared_ptr<ColumnarBatch> VeloxGpuHashShuffleReaderDeserializer::next() {
-  if (in_ == nullptr) {
-    loadNextStream();
+void VeloxGpuHashShuffleReaderDeserializer::stop() {
+  // Signal threads to stop if not already stopped.
+  stop_.store(true, std::memory_order_release);
+  // Wait for all reader threads to complete.
+  std::unique_lock<std::mutex> lock(completionMtx_);

Review Comment:
`stop()` waits for reader threads to exit, but reader threads may be blocked inside `batchQueue_->put(batch)` when the queue is full. In that case `stop()` can deadlock because blocked producers never reach the loop top to observe `stop_`. A robust fix needs a way to unblock/abort `put()` when stopping (e.g., close the queue and have `put()` wake and fail).

--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the URL above to go to the specific comment.
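The fix suggested in the `stop()` review comment (close the queue and have `put()` wake and fail) can be sketched as follows. This is a minimal, hypothetical `ClosableBatchQueue` loosely modeled on the `CachedBatchQueue` added in this PR; `close()`, the `bool` return value of `put()`, and `FakeBatch` are assumptions for illustration, not existing Gluten APIs. The key change is that both wait predicates also test a `closed_` flag, so the `notify_all()` calls in `close()` release every blocked waiter.

```cpp
#include <cassert>
#include <condition_variable>
#include <cstdint>
#include <deque>
#include <memory>
#include <mutex>
#include <thread>

// Hypothetical element type; like CachedBatchQueue<T>, the sketch only needs numBytes().
struct FakeBatch {
  int64_t bytes;
  int64_t numBytes() const { return bytes; }
};

// Illustrative byte-bounded queue with close(); not Gluten's CachedBatchQueue.
template <typename T>
class ClosableBatchQueue {
 public:
  explicit ClosableBatchQueue(int64_t capacity) : capacity_(capacity) {}

  // Blocks while there is no room. Returns false, dropping the batch, if close() is
  // called first, so a producer stuck here can notice shutdown instead of hanging.
  bool put(std::shared_ptr<T> batch) {
    std::unique_lock<std::mutex> lock(mtx_);
    const int64_t size = batch->numBytes();
    notFull_.wait(lock, [&] { return closed_ || totalSize_ + size <= capacity_; });
    if (closed_) {
      return false;
    }
    queue_.push_back(std::move(batch));
    totalSize_ += size;
    notEmpty_.notify_one();
    return true;
  }

  // Returns the next batch, or nullptr once the queue is closed and empty.
  std::shared_ptr<T> get() {
    std::unique_lock<std::mutex> lock(mtx_);
    notEmpty_.wait(lock, [&] { return closed_ || !queue_.empty(); });
    if (queue_.empty()) {
      return nullptr;
    }
    auto batch = std::move(queue_.front());
    queue_.pop_front();
    totalSize_ -= batch->numBytes();
    notFull_.notify_all();
    return batch;
  }

  // Wakes all blocked producers and consumers; later put() calls fail immediately.
  void close() {
    {
      std::lock_guard<std::mutex> lock(mtx_);
      closed_ = true;
    }
    notFull_.notify_all();
    notEmpty_.notify_all();
  }

 private:
  const int64_t capacity_;
  int64_t totalSize_{0};
  bool closed_{false};
  std::deque<std::shared_ptr<T>> queue_;
  std::mutex mtx_;
  std::condition_variable notFull_;
  std::condition_variable notEmpty_;
};

// Usage: a stop()-like sequence that closes the queue BEFORE waiting for the producer.
bool demoCloseUnblocksProducer() {
  ClosableBatchQueue<FakeBatch> queue(/*capacity=*/100);
  const bool firstPut = queue.put(std::make_shared<FakeBatch>(FakeBatch{100}));  // now full
  assert(firstPut);
  bool secondPut = true;
  std::thread producer([&] {
    // 100 + 50 > capacity: this put() blocks like a reader thread on a full queue.
    secondPut = queue.put(std::make_shared<FakeBatch>(FakeBatch{50}));
  });
  queue.close();    // wake the blocked producer ...
  producer.join();  // ... so waiting for it terminates instead of deadlocking
  return !secondPut;
}
```

Under this design, `stop()` would presumably set `stop_`, call `close()` on the queue, and only then wait for the reader threads (e.g., for `activeReaders_` to reach zero); a reader whose `put()` returns `false` would simply exit its loop. Whether batches still queued at that point are dropped or drained by the consumer is a separate policy choice.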
