squito commented on a change in pull request #23453: [SPARK-26089][CORE] Handle 
corruption in large shuffle blocks
URL: https://github.com/apache/spark/pull/23453#discussion_r254875667
 
 

 ##########
 File path: 
core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
 ##########
 @@ -449,54 +452,75 @@ class ShuffleBlockFetcherIteratorSuite extends 
SparkFunSuite with PrivateMethodT
   }
 
   test("big blocks are also checked for corruption") {
-    val corruptBuffer1 = mockCorruptBuffer(10000L, true)
-
+    val streamLength = 10000L
     val blockManager = mock(classOf[BlockManager])
+
+    // This stream will throw IOException when the first byte is read
+    val localBuffer = mockCorruptBuffer(streamLength, 0)
     val localBmId = BlockManagerId("test-client", "test-client", 1)
     doReturn(localBmId).when(blockManager).blockManagerId
-    doReturn(corruptBuffer1).when(blockManager).getBlockData(ShuffleBlockId(0, 
0, 0))
+    doReturn(localBuffer).when(blockManager).getBlockData(ShuffleBlockId(0, 0, 
0))
+    val localShuffleBlockId = ShuffleBlockId(0, 0, 0)
     val localBlockLengths = Seq[Tuple2[BlockId, Long]](
-      ShuffleBlockId(0, 0, 0) -> corruptBuffer1.size()
+      localShuffleBlockId -> localBuffer.size()
     )
 
-    val corruptBuffer2 = mockCorruptBuffer(10000L, false)
+    val streamNotCorruptTill = 8 * 1024
+    // This stream will throw exception after streamNotCorruptTill bytes are 
read
+    val remoteBuffer = mockCorruptBuffer(streamLength, streamNotCorruptTill)
     val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2)
+    val remoteShuffleBlockId = ShuffleBlockId(0, 1, 0)
     val remoteBlockLengths = Seq[Tuple2[BlockId, Long]](
-      ShuffleBlockId(0, 1, 0) -> corruptBuffer2.size()
+      remoteShuffleBlockId -> remoteBuffer.size()
     )
 
     val transfer = createMockTransfer(
-      Map(ShuffleBlockId(0, 0, 0) -> corruptBuffer1, ShuffleBlockId(0, 1, 0) 
-> corruptBuffer2))
-
+      Map(localShuffleBlockId -> localBuffer, remoteShuffleBlockId -> 
remoteBuffer))
     val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
       (localBmId, localBlockLengths),
       (remoteBmId, remoteBlockLengths)
     ).toIterator
-
     val taskContext = TaskContext.empty()
+    val maxBytesInFlight = 3 * 1024
     val iterator = new ShuffleBlockFetcherIterator(
       taskContext,
       transfer,
       blockManager,
       blocksByAddress,
-      (_, in) => new LimitedInputStream(in, 10000),
-      2048,
+      (_, in) => new LimitedInputStream(in, streamLength),
+      maxBytesInFlight,
       Int.MaxValue,
       Int.MaxValue,
       Int.MaxValue,
       true,
       true,
       taskContext.taskMetrics.createTempShuffleReadMetrics())
-    // Only one block should be returned which has corruption after 
maxBytesInFlight/3
+
+    // Only one block should be returned which has corruption after 
maxBytesInFlight/3 because the
+    // other block will be re-fetched
     val (id, st) = iterator.next()
-    assert(id === ShuffleBlockId(0, 1, 0))
-    intercept[FetchFailedException] { iterator.next() }
-    // Following will succeed as it reads the first part of the stream which 
is not corrupt
-    st.read(new Array[Byte](8 * 1024), 0, 8 * 1024)
+    assert(id === remoteShuffleBlockId)
+
+    // The other block will throw a FetchFailedException
+    intercept[FetchFailedException] {
+      iterator.next()
+    }
+
+    // Following will succeed as it reads part of the stream which is not 
corrupt. This will read
+    // maxBytesInFlight/3 bytes from first stream and remaining from the 
second stream
+    new DataInputStream(st).readFully(
+      new Array[Byte](streamNotCorruptTill), 0, streamNotCorruptTill)
+
     // Following will fail as it reads the remaining part of the stream which 
is corrupt
     intercept[FetchFailedException] { st.read() }
-    intercept[FetchFailedException] { st.read(new Array[Byte](8 * 1024)) }
-    intercept[FetchFailedException] { st.read(new Array[Byte](8 * 1024), 0, 8 
* 1024) }
+    intercept[FetchFailedException] { st.read(new Array[Byte](1024)) }
+    intercept[FetchFailedException] { st.read(new Array[Byte](1024), 0, 1024) }
+    intercept[FetchFailedException] { st.skip(1024) }
+
+    IOUtils.closeQuietly(st)
 
 Review comment:
   hmm, I was thinking the stream should get closed automatically after the 
fetch failed exception. Am I thinking about this wrong?

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to