Github user rezasafi commented on a diff in the pull request:

    https://github.com/apache/spark/pull/22325#discussion_r219021183
  
    --- Diff: core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala ---
    @@ -444,36 +445,36 @@ final class ShuffleBlockFetcherIterator(
                   throwFetchFailedException(blockId, address, e)
               }
     
    -          input = streamWrapper(blockId, in)
    -          // Only copy the stream if it's wrapped by compression or encryption, also the size of
    -          // block is small (the decompressed block is smaller than maxBytesInFlight)
    -          if (detectCorrupt && !input.eq(in) && size < maxBytesInFlight / 3) {
    -            val originalInput = input
    -            val out = new ChunkedByteBufferOutputStream(64 * 1024, ByteBuffer.allocate)
    -            try {
    +          try {
    +            input = streamWrapper(blockId, in)
    +            originalInput = input
    +            // Only copy the stream if it's wrapped by compression or encryption, also the size of
    +            // block is small (the decompressed block is smaller than maxBytesInFlight)
    +            if (detectCorrupt && !input.eq(in) && size < maxBytesInFlight / 3) {
    +              val out = new ChunkedByteBufferOutputStream(64 * 1024, ByteBuffer.allocate)
                   // Decompress the whole block at once to detect any corruption, which could increase
                   // the memory usage and potentially increase the chance of OOM.
                   // TODO: manage the memory used here, and spill it into disk in case of OOM.
                   Utils.copyStream(input, out)
                   out.close()
                   input = out.toChunkedByteBuffer.toInputStream(dispose = true)
    -            } catch {
    -              case e: IOException =>
    -                buf.release()
    -                if (buf.isInstanceOf[FileSegmentManagedBuffer]
    -                  || corruptedBlocks.contains(blockId)) {
    -                  throwFetchFailedException(blockId, address, e)
    -                } else {
    -                  logWarning(s"got an corrupted block $blockId from 
$address, fetch again", e)
    -                  corruptedBlocks += blockId
    -                  fetchRequests += FetchRequest(address, Array((blockId, 
size)))
    -                  result = null
    -                }
    -            } finally {
    -              // TODO: release the buf here to free memory earlier
    -              originalInput.close()
    -              in.close()
                 }
    +          } catch {
    +            case e: IOException =>
    +              buf.release()
    +              if (buf.isInstanceOf[FileSegmentManagedBuffer]
    +                || corruptedBlocks.contains(blockId)) {
    +                throwFetchFailedException(blockId, address, e)
    +              } else {
    +                logWarning(s"got an corrupted block $blockId from 
$address, fetch again", e)
    +                corruptedBlocks += blockId
    +                fetchRequests += FetchRequest(address, Array((blockId, 
size)))
    +                result = null
    +              }
    +          } finally {
    +            // TODO: release the buf here to free memory earlier
    +            originalInput.close()
    --- End diff ---
    
    Yeah, your suggestion works. I will update this shortly.
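
    [Editor's note] For context on what the restructured try/catch is buying: below is a minimal, self-contained sketch of the same corruption-detection pattern, using plain JDK streams instead of Spark's ChunkedByteBufferOutputStream. The names detectCorruption and wrap are hypothetical, not Spark APIs; this is an illustration of the technique, not the actual implementation.

        import java.io.{ByteArrayInputStream, ByteArrayOutputStream, IOException, InputStream}

        object CorruptionCheck {
          // Drain the wrapped (decompressing/decrypting) stream completely up
          // front, so a corrupt block fails with an IOException here, where
          // the fetch can be retried, rather than later mid-task.
          def detectCorruption(
              raw: InputStream,
              wrap: InputStream => InputStream): Either[IOException, InputStream] = {
            val wrapped = wrap(raw)
            try {
              val out = new ByteArrayOutputStream()
              val buffer = new Array[Byte](64 * 1024)
              var n = wrapped.read(buffer)
              while (n != -1) {
                out.write(buffer, 0, n)
                n = wrapped.read(buffer)
              }
              // The whole block decompressed cleanly; hand back an in-memory copy.
              Right(new ByteArrayInputStream(out.toByteArray))
            } catch {
              // e.g. wrap = new GZIPInputStream(_) turns a truncated or corrupted
              // block into an IOException; the caller can re-enqueue the fetch.
              case e: IOException => Left(e)
            } finally {
              // Mirrors the diff's finally block: close the wrapper and the
              // underlying stream regardless of success or failure.
              wrapped.close()
              raw.close()
            }
          }
        }

    The trade-off, flagged by the TODO in the diff, is that the fully decompressed block is held in memory, which is why this path is only taken for small blocks (size < maxBytesInFlight / 3).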


---
