GitHub user davies commented on a diff in the pull request:
https://github.com/apache/spark/pull/15923#discussion_r90085604
--- Diff:
core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
---
@@ -305,40 +312,82 @@ final class ShuffleBlockFetcherIterator(
*/
override def next(): (BlockId, InputStream) = {
numBlocksProcessed += 1
- val startFetchWait = System.currentTimeMillis()
- currentResult = results.take()
- val result = currentResult
- val stopFetchWait = System.currentTimeMillis()
- shuffleMetrics.incFetchWaitTime(stopFetchWait - startFetchWait)
-
- result match {
- case SuccessFetchResult(_, address, size, buf, isNetworkReqDone) =>
- if (address != blockManager.blockManagerId) {
- shuffleMetrics.incRemoteBytesRead(buf.size)
- shuffleMetrics.incRemoteBlocksFetched(1)
- }
- bytesInFlight -= size
- if (isNetworkReqDone) {
- reqsInFlight -= 1
- logDebug("Number of requests in flight " + reqsInFlight)
- }
- case _ =>
- }
- // Send fetch requests up to maxBytesInFlight
- fetchUpToMaxBytes()
- result match {
- case FailureFetchResult(blockId, address, e) =>
- throwFetchFailedException(blockId, address, e)
+ var result: FetchResult = null
+ var input: InputStream = null
+ // Take the next fetched result and try to decompress it to detect
data corruption,
+ // then fetch it one more time if it's corrupt, throw
FailureFetchResult if the second fetch
+ // is also corrupt, so the previous stage could be retried.
+ // For local shuffle block, throw FailureFetchResult for the first
IOException.
+ while (result == null) {
+ val startFetchWait = System.currentTimeMillis()
+ result = results.take()
+ val stopFetchWait = System.currentTimeMillis()
+ shuffleMetrics.incFetchWaitTime(stopFetchWait - startFetchWait)
- case SuccessFetchResult(blockId, address, _, buf, _) =>
- try {
- (result.blockId, new
BufferReleasingInputStream(buf.createInputStream(), this))
- } catch {
- case NonFatal(t) =>
- throwFetchFailedException(blockId, address, t)
- }
+ result match {
+ case r @ SuccessFetchResult(blockId, address, size, buf,
isNetworkReqDone) =>
+ if (address != blockManager.blockManagerId) {
+ shuffleMetrics.incRemoteBytesRead(buf.size)
+ shuffleMetrics.incRemoteBlocksFetched(1)
+ }
+ bytesInFlight -= size
+ if (isNetworkReqDone) {
+ reqsInFlight -= 1
+ logDebug("Number of requests in flight " + reqsInFlight)
+ }
+
+ val in = try {
+ buf.createInputStream()
+ } catch {
+ // The exception could only be throwed by local shuffle block
+ case e: IOException =>
+ assert(buf.isInstanceOf[FileSegmentManagedBuffer])
+ logError("Failed to create input stream from local block", e)
+ buf.release()
+ throwFetchFailedException(blockId, address, e)
+ }
+
+ input = streamWrapper(blockId, in)
+ // Only copy the stream if it's wrapped by compression or
encryption, also the size of
+ // block is small (the decompressed block is smaller than
maxBytesInFlight)
+ if (detectCorrupt && !input.eq(in) && size < maxBytesInFlight /
3) {
+ val out = new ChunkedByteBufferOutputStream(64 * 1024,
ByteBuffer.allocate)
+ try {
+ // Decompress the whole block at once to detect any
corruption, which could increase
+ // the memory usage tne potential increase the chance of OOM.
+ // TODO: manage the memory used here, and spill it into disk
in case of OOM.
+ Utils.copyStream(input, out)
--- End diff --
done
---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and wishes to have it, or if the feature is enabled but not working,
please contact infrastructure at [email protected] or file a JIRA ticket
with INFRA.
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]