squito commented on a change in pull request #23453: [SPARK-26089][CORE] Handle
corruption in large shuffle blocks
URL: https://github.com/apache/spark/pull/23453#discussion_r253971399
##########
File path:
core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
##########
@@ -337,9 +342,26 @@ class ShuffleBlockFetcherIteratorSuite extends
SparkFunSuite with PrivateMethodT
intercept[FetchFailedException] { iterator.next() }
}
- private def mockCorruptBuffer(size: Long = 1L): ManagedBuffer = {
+ private def mockCorruptBuffer(size: Long = 1L, corruptInStart: Boolean =
true): ManagedBuffer = {
val corruptStream = mock(classOf[InputStream])
- when(corruptStream.read(any(), any(), any())).thenThrow(new
IOException("corrupt"))
+ if (size < 8 * 1024 || corruptInStart) {
+ when(corruptStream.read(any(), any(), any())).thenThrow(new
IOException("corrupt"))
+ } else {
+ when(corruptStream.read(any(), any(), any(classOf[Int]))).thenAnswer(new
Answer[Int] {
+ override def answer(invocationOnMock: InvocationOnMock): Int = {
+ val bufSize = invocationOnMock.getArguments()(2).asInstanceOf[Int]
+ // This condition is needed as we don't throw exception until we
read the stream
+ // less than maxBytesInFlight/3
+ if (bufSize < 8 * 1024) {
+ return bufSize
+ } else {
+ throw new IOException("corrupt")
+ }
+ }
+ })
+ when(corruptStream.read()).thenThrow(new IOException("corrupt"))
+ when(corruptStream.read(any())).thenThrow(new IOException("corrupt"))
Review comment:
this isn't really doing what you want -- with these versions of the call,
you're throwing the exception regardless of the position. And even above,
you're throwing based on size of the buffer, not the position in the stream.
It's probably easiest to just implement your own InputStream rather than using a
mock at this point.
```scala
/**
 * Builds a mocked `ManagedBuffer` backed by a `CorruptStream` and returns both, so a caller
 * can also inspect the stream itself.
 *
 * @param size value the mocked buffer reports from `size()`
 * @param corruptInStart when true the stream is created with a corruption offset of 0;
 *                       otherwise with an offset of 8 * 1024
 * @return the mocked buffer paired with the stream it hands out from `createInputStream()`
 */
private def mockCorruptBufferAndStream(
    size: Long = 1L,
    corruptInStart: Boolean): (ManagedBuffer, CorruptStream) = {
  val failureOffset = if (!corruptInStart) 8 * 1024 else 0
  val stream = new CorruptStream(failureOffset)
  val buffer = mock(classOf[ManagedBuffer])
  when(buffer.createInputStream()).thenReturn(stream)
  when(buffer.size()).thenReturn(size)
  buffer -> stream
}
/**
 * Convenience wrapper around `mockCorruptBufferAndStream` that keeps only the buffer.
 *
 * @param size forwarded unchanged to `mockCorruptBufferAndStream`
 * @param corruptInStart forwarded unchanged to `mockCorruptBufferAndStream`
 * @return the first element of the pair produced by `mockCorruptBufferAndStream`
 */
private def mockCorruptBuffer(
    size: Long = 1L,
    corruptInStart: Boolean = true): ManagedBuffer = {
  val (buffer, _) = mockCorruptBufferAndStream(size, corruptInStart)
  buffer
}
/**
 * An `InputStream` that serves bytes until a given position and then fails.
 *
 * Each successful single-byte read advances `pos` by one. As soon as `pos` has reached
 * `corruptAt`, every further read throws `IOException("corrupt")`.
 *
 * @param corruptAt number of bytes that can be read before the stream starts throwing;
 *                  0 (the default) makes the very first read fail
 */
class CorruptStream(corruptAt: Long = 0L) extends InputStream {
  // Number of bytes served so far.
  var pos = 0
  // Set once close() has been called.
  var closed = false

  /**
   * Reads one byte.
   *
   * @return the low 8 bits of the updated position, so the result always lies in 0-255 as
   *         the `InputStream.read()` contract requires
   * @throws IOException once `pos` has reached `corruptAt`
   */
  override def read(): Int = {
    if (pos >= corruptAt) {
      throw new IOException("corrupt")
    } else {
      pos += 1
      // Mask to a single unsigned byte: returning the raw position would hand back values
      // above 255 after the first 255 bytes. Bytes seen through the bulk read are unchanged,
      // because that path narrows each value to a Byte anyway.
      pos & 0xFF
    }
  }

  /**
   * Bulk read, delegated to `InputStream`'s default implementation, which fills `dest` by
   * calling `read()` repeatedly; corruption is therefore driven by the stream position
   * rather than by the size of the requested buffer.
   */
  override def read(dest: Array[Byte], off: Int, len: Int): Int = {
    super.read(dest, off, len)
  }

  /** Records that the stream was closed. */
  override def close(): Unit = { closed = true }
}
```
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
With regards,
Apache Git Services
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]