This is an automated email from the ASF dual-hosted git repository.
mridulm80 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 186477c [SPARK-35263][TEST] Refactor ShuffleBlockFetcherIteratorSuite
to reduce duplicated code
186477c is described below
commit 186477c60e9cad71434b15fd9e08789740425d59
Author: Erik Krogen <[email protected]>
AuthorDate: Tue May 18 22:37:47 2021 -0500
[SPARK-35263][TEST] Refactor ShuffleBlockFetcherIteratorSuite to reduce
duplicated code
### What changes were proposed in this pull request?
Introduce new shared methods to `ShuffleBlockFetcherIteratorSuite` to
replace copy-pasted code. Use modern, Scala-like Mockito `Answer` syntax.
### Why are the changes needed?
`ShuffleBlockFetcherIteratorSuite` has tons of duplicate code, like
https://github.com/apache/spark/blob/0494dc90af48ce7da0625485a4dc6917a244d580/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala#L172-L185
. It's challenging to tell what the interesting parts are vs. what is just
being set to some default/unused value.
Similarly, though not as severe, there are many calls like the following
```
verify(transfer, times(1)).fetchBlocks(any(), any(), any(), any(), any(),
any())
when(transfer.fetchBlocks(any(), any(), any(), any(), any(),
any())).thenAnswer ...
```
These changes result in about 10% reduction in both lines and characters in
the file:
```bash
# Before
> wc
core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
1063 3950 43201
core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
# After
> wc
core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
928 3609 39053
core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
```
It also helps readability, e.g.:
```
val iterator = createShuffleBlockIteratorWithDefaults(
transfer,
blocksByAddress,
maxBytesInFlight = 1000L
)
```
Now I can clearly tell that `maxBytesInFlight` is the main parameter we're
interested in here.
### Does this PR introduce _any_ user-facing change?
No, test only. There aren't even any behavior changes, just refactoring.
### How was this patch tested?
Unit tests pass.
Closes #32389 from
xkrogen/xkrogen-spark-35263-refactor-shuffleblockfetcheriteratorsuite.
Authored-by: Erik Krogen <[email protected]>
Signed-off-by: Mridul Muralidharan <mridul<at>gmail.com>
---
.../storage/ShuffleBlockFetcherIteratorSuite.scala | 689 ++++++++-------------
1 file changed, 245 insertions(+), 444 deletions(-)
diff --git
a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
index 99c43b1..4be5fae 100644
---
a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
+++
b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
@@ -27,7 +27,7 @@ import scala.concurrent.Future
import org.mockito.ArgumentMatchers.{any, eq => meq}
import org.mockito.Mockito.{mock, times, verify, when}
-import org.mockito.invocation.InvocationOnMock
+import org.mockito.stubbing.Answer
import org.scalatest.PrivateMethodTester
import org.apache.spark.{SparkFunSuite, TaskContext}
@@ -35,35 +35,44 @@ import org.apache.spark.network._
import org.apache.spark.network.buffer.{FileSegmentManagedBuffer,
ManagedBuffer}
import org.apache.spark.network.shuffle.{BlockFetchingListener,
DownloadFileManager, ExternalBlockStoreClient}
import org.apache.spark.network.util.LimitedInputStream
-import org.apache.spark.shuffle.FetchFailedException
+import org.apache.spark.shuffle.{FetchFailedException,
ShuffleReadMetricsReporter}
import org.apache.spark.storage.ShuffleBlockFetcherIterator.FetchBlockInfo
import org.apache.spark.util.Utils
class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with
PrivateMethodTester {
+ private var transfer: BlockTransferService = _
+
+ override def beforeEach(): Unit = {
+ transfer = mock(classOf[BlockTransferService])
+ }
+
private def doReturn(value: Any) = org.mockito.Mockito.doReturn(value,
Seq.empty: _*)
+ private def answerFetchBlocks(answer: Answer[Unit]): Unit =
+ when(transfer.fetchBlocks(any(), any(), any(), any(), any(),
any())).thenAnswer(answer)
+
+ private def verifyFetchBlocksInvocationCount(expectedCount: Int): Unit =
+ verify(transfer, times(expectedCount)).fetchBlocks(any(), any(), any(),
any(), any(), any())
+
// Some of the tests are quite tricky because we are testing the cleanup
behavior
// in the presence of faults.
- /** Creates a mock [[BlockTransferService]] that returns data from the given
map. */
- private def createMockTransfer(data: Map[BlockId, ManagedBuffer]):
BlockTransferService = {
- val transfer = mock(classOf[BlockTransferService])
- when(transfer.fetchBlocks(any(), any(), any(), any(), any(),
any())).thenAnswer(
- (invocation: InvocationOnMock) => {
- val blocks = invocation.getArguments()(3).asInstanceOf[Array[String]]
- val listener =
invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
-
- for (blockId <- blocks) {
- if (data.contains(BlockId(blockId))) {
- listener.onBlockFetchSuccess(blockId, data(BlockId(blockId)))
- } else {
- listener.onBlockFetchFailure(blockId, new
BlockNotFoundException(blockId))
- }
+ /** Configures `transfer` (mock [[BlockTransferService]]) to return data
from the given map. */
+ private def configureMockTransfer(data: Map[BlockId, ManagedBuffer]): Unit =
{
+ answerFetchBlocks { invocation =>
+ val blocks = invocation.getArgument[Array[String]](3)
+ val listener = invocation.getArgument[BlockFetchingListener](4)
+
+ for (blockId <- blocks) {
+ if (data.contains(BlockId(blockId))) {
+ listener.onBlockFetchSuccess(blockId, data(BlockId(blockId)))
+ } else {
+ listener.onBlockFetchFailure(blockId, new
BlockNotFoundException(blockId))
}
- })
- transfer
+ }
+ }
}
private def createMockBlockManager(): BlockManager = {
@@ -88,10 +97,9 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite
with PrivateMethodT
when(blockManager.hostLocalDirManager).thenReturn(Some(hostLocalDirManager))
when(mockExternalBlockStoreClient.getHostLocalDirs(any(), any(), any(),
any()))
.thenAnswer { invocation =>
- val completableFuture = invocation.getArguments()(3)
- .asInstanceOf[CompletableFuture[java.util.Map[String,
Array[String]]]]
import scala.collection.JavaConverters._
- completableFuture.complete(hostLocalDirs.asJava)
+ invocation.getArgument[CompletableFuture[java.util.Map[String,
Array[String]]]](3)
+ .complete(hostLocalDirs.asJava)
}
blockManager.hostLocalDirManager = Some(hostLocalDirManager)
@@ -123,6 +131,49 @@ class ShuffleBlockFetcherIteratorSuite extends
SparkFunSuite with PrivateMethodT
verify(wrappedInputStream.invokePrivate(delegateAccess()),
times(1)).close()
}
+ // scalastyle:off argcount
+ private def createShuffleBlockIteratorWithDefaults(
+ blocksByAddress: Map[BlockManagerId, Seq[(BlockId, Long, Int)]],
+ taskContext: Option[TaskContext] = None,
+ streamWrapperLimitSize: Option[Long] = None,
+ blockManager: Option[BlockManager] = None,
+ maxBytesInFlight: Long = Long.MaxValue,
+ maxReqsInFlight: Int = Int.MaxValue,
+ maxBlocksInFlightPerAddress: Int = Int.MaxValue,
+ maxReqSizeShuffleToMem: Int = Int.MaxValue,
+ detectCorrupt: Boolean = true,
+ detectCorruptUseExtraMemory: Boolean = true,
+ shuffleMetrics: Option[ShuffleReadMetricsReporter] = None,
+ doBatchFetch: Boolean = false): ShuffleBlockFetcherIterator = {
+ val tContext = taskContext.getOrElse(TaskContext.empty())
+ new ShuffleBlockFetcherIterator(
+ tContext,
+ transfer,
+ blockManager.getOrElse(createMockBlockManager()),
+ blocksByAddress.toIterator,
+ (_, in) => streamWrapperLimitSize.map(new LimitedInputStream(in,
_)).getOrElse(in),
+ maxBytesInFlight,
+ maxReqsInFlight,
+ maxBlocksInFlightPerAddress,
+ maxReqSizeShuffleToMem,
+ detectCorrupt,
+ detectCorruptUseExtraMemory,
+
shuffleMetrics.getOrElse(tContext.taskMetrics().createTempShuffleReadMetrics()),
+ doBatchFetch)
+ }
+ // scalastyle:on argcount
+
+ /**
+ * Convert a list of block IDs into a list of blocks with metadata, assuming
all blocks have the
+ * same size and index.
+ */
+ private def toBlockList(
+ blockIds: Traversable[BlockId],
+ blockSize: Long,
+ blockMapIndex: Int): Seq[(BlockId, Long, Int)] = {
+ blockIds.map(blockId => (blockId, blockSize, blockMapIndex)).toSeq
+ }
+
test("successful 3 local + 4 host local + 2 remote reads") {
val blockManager = createMockBlockManager()
val localBmId = blockManager.blockManagerId
@@ -142,15 +193,11 @@ class ShuffleBlockFetcherIteratorSuite extends
SparkFunSuite with PrivateMethodT
ShuffleBlockId(0, 3, 0) -> createMockManagedBuffer(),
ShuffleBlockId(0, 4, 0) -> createMockManagedBuffer())
- val transfer = createMockTransfer(remoteBlocks)
+ configureMockTransfer(remoteBlocks)
// Create a block manager running on the same host (host-local)
val hostLocalBmId = BlockManagerId("test-host-local-client-1",
"test-local-host", 3)
- val hostLocalBlocks = Map[BlockId, ManagedBuffer](
- ShuffleBlockId(0, 5, 0) -> createMockManagedBuffer(),
- ShuffleBlockId(0, 6, 0) -> createMockManagedBuffer(),
- ShuffleBlockId(0, 7, 0) -> createMockManagedBuffer(),
- ShuffleBlockId(0, 8, 0) -> createMockManagedBuffer())
+ val hostLocalBlocks = 5.to(8).map(ShuffleBlockId(0, _, 0) ->
createMockManagedBuffer()).toMap
hostLocalBlocks.foreach { case (blockId, buf) =>
doReturn(buf)
@@ -161,28 +208,14 @@ class ShuffleBlockFetcherIteratorSuite extends
SparkFunSuite with PrivateMethodT
// returning local dir for hostLocalBmId
initHostLocalDirManager(blockManager, hostLocalDirs)
- val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])](
- (localBmId, localBlocks.keys.map(blockId => (blockId, 1L, 0)).toSeq),
- (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 1L, 1)).toSeq),
- (hostLocalBmId, hostLocalBlocks.keys.map(blockId => (blockId, 1L,
1)).toSeq)
- ).toIterator
-
- val taskContext = TaskContext.empty()
- val metrics = taskContext.taskMetrics.createTempShuffleReadMetrics()
- val iterator = new ShuffleBlockFetcherIterator(
- taskContext,
- transfer,
- blockManager,
- blocksByAddress,
- (_, in) => in,
- 48 * 1024 * 1024,
- Int.MaxValue,
- Int.MaxValue,
- Int.MaxValue,
- true,
- false,
- metrics,
- false)
+ val iterator = createShuffleBlockIteratorWithDefaults(
+ Map(
+ localBmId -> toBlockList(localBlocks.keys, 1L, 0),
+ remoteBmId -> toBlockList(remoteBlocks.keys, 1L, 1),
+ hostLocalBmId -> toBlockList(hostLocalBlocks.keys, 1L, 1)
+ ),
+ blockManager = Some(blockManager)
+ )
// 3 local blocks fetched in initialization
verify(blockManager, times(3)).getLocalBlockData(any())
@@ -203,7 +236,7 @@ class ShuffleBlockFetcherIteratorSuite extends
SparkFunSuite with PrivateMethodT
.getHostLocalShuffleData(any(), meq(Array("local-dir")))
// 2 remote blocks are read from the same block manager
- verify(transfer, times(1)).fetchBlocks(any(), any(), any(), any(), any(),
any())
+ verifyFetchBlocksInvocationCount(1)
assert(blockManager.hostLocalDirManager.get.getCachedHostLocalDirs.size
=== 1)
}
@@ -228,117 +261,64 @@ class ShuffleBlockFetcherIteratorSuite extends
SparkFunSuite with PrivateMethodT
when(blockManager.hostLocalDirManager).thenReturn(Some(hostLocalDirManager))
when(mockExternalBlockStoreClient.getHostLocalDirs(any(), any(), any(),
any()))
.thenAnswer { invocation =>
- val completableFuture = invocation.getArguments()(3)
- .asInstanceOf[CompletableFuture[java.util.Map[String,
Array[String]]]]
- completableFuture.completeExceptionally(new Throwable("failed fetch"))
+ invocation.getArgument[CompletableFuture[java.util.Map[String,
Array[String]]]](3)
+ .completeExceptionally(new Throwable("failed fetch"))
}
blockManager.hostLocalDirManager = Some(hostLocalDirManager)
- val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])](
- (hostLocalBmId, hostLocalBlocks.keys.map(blockId => (blockId, 1L,
1)).toSeq)
- ).toIterator
- val transfer = createMockTransfer(Map())
- val taskContext = TaskContext.empty()
- val metrics = taskContext.taskMetrics.createTempShuffleReadMetrics()
- val iterator = new ShuffleBlockFetcherIterator(
- taskContext,
- transfer,
- blockManager,
- blocksByAddress,
- (_, in) => in,
- 48 * 1024 * 1024,
- Int.MaxValue,
- Int.MaxValue,
- Int.MaxValue,
- true,
- false,
- metrics,
- false)
+ configureMockTransfer(Map())
+ val iterator = createShuffleBlockIteratorWithDefaults(
+ Map(hostLocalBmId -> toBlockList(hostLocalBlocks.keys, 1L, 1))
+ )
intercept[FetchFailedException] { iterator.next() }
}
test("Hit maxBytesInFlight limitation before maxBlocksInFlightPerAddress") {
- val blockManager = createMockBlockManager()
val remoteBmId1 = BlockManagerId("test-remote-client-1",
"test-remote-host1", 1)
val remoteBmId2 = BlockManagerId("test-remote-client-2",
"test-remote-host2", 2)
val blockId1 = ShuffleBlockId(0, 1, 0)
val blockId2 = ShuffleBlockId(1, 1, 0)
- val blocksByAddress = Seq(
- (remoteBmId1, Seq((blockId1, 1000L, 0))),
- (remoteBmId2, Seq((blockId2, 1000L, 0)))).toIterator
- val transfer = createMockTransfer(Map(
+ configureMockTransfer(Map(
blockId1 -> createMockManagedBuffer(1000),
blockId2 -> createMockManagedBuffer(1000)))
- val taskContext = TaskContext.empty()
- val metrics = taskContext.taskMetrics.createTempShuffleReadMetrics()
- val iterator = new ShuffleBlockFetcherIterator(
- taskContext,
- transfer,
- blockManager,
- blocksByAddress,
- (_, in) => in,
- 1000L, // allow 1 FetchRequests at most at the same time
- Int.MaxValue,
- Int.MaxValue, // set maxBlocksInFlightPerAddress to Int.MaxValue
- Int.MaxValue,
- true,
- false,
- metrics,
- false)
+ val iterator = createShuffleBlockIteratorWithDefaults(Map(
+ remoteBmId1 -> toBlockList(Seq(blockId1), 1000L, 0),
+ remoteBmId2 -> toBlockList(Seq(blockId2), 1000L, 0)
+ ), maxBytesInFlight = 1000L)
// After initialize() we'll have 2 FetchRequests and each is 1000 bytes.
So only the
// first FetchRequests can be sent, and the second one will hit
maxBytesInFlight so
// it won't be sent.
- verify(transfer, times(1)).fetchBlocks(any(), any(), any(), any(), any(),
any())
+ verifyFetchBlocksInvocationCount(1)
assert(iterator.hasNext)
// next() will trigger off sending deferred request
iterator.next()
// the second FetchRequest should be sent at this time
- verify(transfer, times(2)).fetchBlocks(any(), any(), any(), any(), any(),
any())
+ verifyFetchBlocksInvocationCount(2)
assert(iterator.hasNext)
iterator.next()
assert(!iterator.hasNext)
}
test("Hit maxBlocksInFlightPerAddress limitation before maxBytesInFlight") {
- val blockManager = createMockBlockManager()
val remoteBmId = BlockManagerId("test-remote-client-1",
"test-remote-host", 2)
- val blockId1 = ShuffleBlockId(0, 1, 0)
- val blockId2 = ShuffleBlockId(0, 2, 0)
- val blockId3 = ShuffleBlockId(0, 3, 0)
- val blocksByAddress = Seq((remoteBmId,
- Seq((blockId1, 1000L, 0), (blockId2, 1000L, 0), (blockId3, 1000L,
0)))).toIterator
- val transfer = createMockTransfer(Map(
- blockId1 -> createMockManagedBuffer(),
- blockId2 -> createMockManagedBuffer(),
- blockId3 -> createMockManagedBuffer()))
- val taskContext = TaskContext.empty()
- val metrics = taskContext.taskMetrics.createTempShuffleReadMetrics()
- val iterator = new ShuffleBlockFetcherIterator(
- taskContext,
- transfer,
- blockManager,
- blocksByAddress,
- (_, in) => in,
- Int.MaxValue, // set maxBytesInFlight to Int.MaxValue
- Int.MaxValue,
- 2, // set maxBlocksInFlightPerAddress to 2
- Int.MaxValue,
- true,
- false,
- metrics,
- false)
+ val blocks = 1.to(3).map(ShuffleBlockId(0, _, 0))
+ configureMockTransfer(blocks.map(_ -> createMockManagedBuffer()).toMap)
+ val iterator = createShuffleBlockIteratorWithDefaults(
+ Map(remoteBmId -> toBlockList(blocks, 1000L, 0)),
+ maxBlocksInFlightPerAddress = 2
+ )
// After initialize(), we'll have 2 FetchRequests that one has 2 blocks
inside and another one
// has only one block. So only the first FetchRequest can be sent. The
second FetchRequest will
// hit maxBlocksInFlightPerAddress so it won't be sent.
- verify(transfer, times(1)).fetchBlocks(any(), any(), any(), any(), any(),
any())
+ verifyFetchBlocksInvocationCount(1)
// the first request packaged 2 blocks, so we also need to
// call next() for 2 times to exhaust the iterator.
assert(iterator.hasNext)
iterator.next()
assert(iterator.hasNext)
iterator.next()
- verify(transfer, times(2)).fetchBlocks(any(), any(), any(), any(), any(),
any())
+ verifyFetchBlocksInvocationCount(2)
assert(iterator.hasNext)
iterator.next()
assert(!iterator.hasNext)
@@ -365,7 +345,7 @@ class ShuffleBlockFetcherIteratorSuite extends
SparkFunSuite with PrivateMethodT
ShuffleBlockId(0, 3, 1))
val mergedRemoteBlocks = Map[BlockId, ManagedBuffer](
ShuffleBlockBatchId(0, 3, 0, 2) -> createMockManagedBuffer())
- val transfer = createMockTransfer(mergedRemoteBlocks)
+ configureMockTransfer(mergedRemoteBlocks)
// Create a block manager running on the same host (host-local)
val hostLocalBmId = BlockManagerId("test-host-local-client-1",
"test-local-host", 3)
@@ -386,28 +366,15 @@ class ShuffleBlockFetcherIteratorSuite extends
SparkFunSuite with PrivateMethodT
// returning local dir for hostLocalBmId
initHostLocalDirManager(blockManager, hostLocalDirs)
- val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])](
- (localBmId, localBlocks.map(blockId => (blockId, 1L, 0))),
- (remoteBmId, remoteBlocks.map(blockId => (blockId, 1L, 1))),
- (hostLocalBmId, hostLocalBlocks.keys.map(blockId => (blockId, 1L,
1)).toSeq)
- ).toIterator
-
- val taskContext = TaskContext.empty()
- val metrics = taskContext.taskMetrics.createTempShuffleReadMetrics()
- val iterator = new ShuffleBlockFetcherIterator(
- taskContext,
- transfer,
- blockManager,
- blocksByAddress,
- (_, in) => in,
- 48 * 1024 * 1024,
- Int.MaxValue,
- Int.MaxValue,
- Int.MaxValue,
- true,
- false,
- metrics,
- true)
+ val iterator = createShuffleBlockIteratorWithDefaults(
+ Map(
+ localBmId -> toBlockList(localBlocks, 1L, 0),
+ remoteBmId -> toBlockList(remoteBlocks, 1L, 1),
+ hostLocalBmId -> toBlockList(hostLocalBlocks.keys, 1L, 1)
+ ),
+ blockManager = Some(blockManager),
+ doBatchFetch = true
+ )
// 3 local blocks batch fetched in initialization
verify(blockManager, times(1)).getLocalBlockData(any())
@@ -416,7 +383,7 @@ class ShuffleBlockFetcherIteratorSuite extends
SparkFunSuite with PrivateMethodT
for (i <- 0 until 3) {
assert(iterator.hasNext, s"iterator should have 3 elements but actually
has $i elements")
val (blockId, inputStream) = iterator.next()
- verify(transfer, times(1)).fetchBlocks(any(), any(), any(), any(),
any(), any())
+ verifyFetchBlocksInvocationCount(1)
// Make sure we release buffers when a wrapped input stream is closed.
val mockBuf = allBlocks(blockId)
verifyBufferRelease(mockBuf, inputStream)
@@ -430,7 +397,6 @@ class ShuffleBlockFetcherIteratorSuite extends
SparkFunSuite with PrivateMethodT
}
test("fetch continuous blocks in batch should respect maxBytesInFlight") {
- val blockManager = createMockBlockManager()
// Make sure remote blocks would return the merged block
val remoteBmId1 = BlockManagerId("test-client-1", "test-client-1", 1)
val remoteBmId2 = BlockManagerId("test-client-2", "test-client-2", 2)
@@ -443,28 +409,16 @@ class ShuffleBlockFetcherIteratorSuite extends
SparkFunSuite with PrivateMethodT
ShuffleBlockBatchId(0, 3, 9, 12) -> createMockManagedBuffer(),
ShuffleBlockBatchId(0, 3, 12, 15) -> createMockManagedBuffer(),
ShuffleBlockBatchId(0, 4, 0, 2) -> createMockManagedBuffer())
- val transfer = createMockTransfer(mergedRemoteBlocks)
-
- val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])](
- (remoteBmId1, remoteBlocks1.map(blockId => (blockId, 100L, 1))),
- (remoteBmId2, remoteBlocks2.map(blockId => (blockId, 100L,
1)))).toIterator
-
- val taskContext = TaskContext.empty()
- val metrics = taskContext.taskMetrics.createTempShuffleReadMetrics()
- val iterator = new ShuffleBlockFetcherIterator(
- taskContext,
- transfer,
- blockManager,
- blocksByAddress,
- (_, in) => in,
- 1500,
- Int.MaxValue,
- Int.MaxValue,
- Int.MaxValue,
- true,
- false,
- metrics,
- true)
+ configureMockTransfer(mergedRemoteBlocks)
+
+ val iterator = createShuffleBlockIteratorWithDefaults(
+ Map(
+ remoteBmId1 -> toBlockList(remoteBlocks1, 100L, 1),
+ remoteBmId2 -> toBlockList(remoteBlocks2, 100L, 1)
+ ),
+ maxBytesInFlight = 1500,
+ doBatchFetch = true
+ )
var numResults = 0
// After initialize(), there will be 6 FetchRequests. And each of the
first 5 requests
@@ -472,7 +426,7 @@ class ShuffleBlockFetcherIteratorSuite extends
SparkFunSuite with PrivateMethodT
// block which merged from 2 shuffle blocks. So, only the first 5
requests(5 * 3 * 100 >= 1500)
// can be sent. The 6th FetchRequest will hit maxBlocksInFlightPerAddress
so it won't
// be sent.
- verify(transfer, times(5)).fetchBlocks(any(), any(), any(), any(), any(),
any())
+ verifyFetchBlocksInvocationCount(5)
while (iterator.hasNext) {
val (blockId, inputStream) = iterator.next()
// Make sure we release buffers when a wrapped input stream is closed.
@@ -481,12 +435,11 @@ class ShuffleBlockFetcherIteratorSuite extends
SparkFunSuite with PrivateMethodT
numResults += 1
}
// The 6th request will be sent after next() is called.
- verify(transfer, times(6)).fetchBlocks(any(), any(), any(), any(), any(),
any())
+ verifyFetchBlocksInvocationCount(6)
assert(numResults == 6)
}
test("fetch continuous blocks in batch should respect
maxBlocksInFlightPerAddress") {
- val blockManager = createMockBlockManager()
// Make sure remote blocks would return the merged block
val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 1)
val remoteBlocks = Seq(
@@ -500,31 +453,18 @@ class ShuffleBlockFetcherIteratorSuite extends
SparkFunSuite with PrivateMethodT
ShuffleBlockBatchId(0, 4, 0, 2) -> createMockManagedBuffer(),
ShuffleBlockBatchId(0, 5, 0, 1) -> createMockManagedBuffer())
- val transfer = createMockTransfer(mergedRemoteBlocks)
- val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])](
- (remoteBmId, remoteBlocks.map(blockId => (blockId, 100L, 1)))).toIterator
- val taskContext = TaskContext.empty()
- val metrics = taskContext.taskMetrics.createTempShuffleReadMetrics()
- val iterator = new ShuffleBlockFetcherIterator(
- taskContext,
- transfer,
- blockManager,
- blocksByAddress,
- (_, in) => in,
- Int.MaxValue,
- Int.MaxValue,
- 2,
- Int.MaxValue,
- true,
- false,
- metrics,
- true)
+ configureMockTransfer(mergedRemoteBlocks)
+ val iterator = createShuffleBlockIteratorWithDefaults(
+ Map(remoteBmId -> toBlockList(remoteBlocks, 100L, 1)),
+ maxBlocksInFlightPerAddress = 2,
+ doBatchFetch = true
+ )
var numResults = 0
// After initialize(), there will be 2 FetchRequests. First one has 2
merged blocks and each
// of them is merged from 2 shuffle blocks, second one has 1 merged block
which is merged from
// 1 shuffle block. So only the first FetchRequest can be sent. The second
FetchRequest will
// hit maxBlocksInFlightPerAddress so it won't be sent.
- verify(transfer, times(1)).fetchBlocks(any(), any(), any(), any(), any(),
any())
+ verifyFetchBlocksInvocationCount(1)
while (iterator.hasNext) {
val (blockId, inputStream) = iterator.next()
// Make sure we release buffers when a wrapped input stream is closed.
@@ -533,12 +473,11 @@ class ShuffleBlockFetcherIteratorSuite extends
SparkFunSuite with PrivateMethodT
numResults += 1
}
// The second request will be sent after next() is called.
- verify(transfer, times(2)).fetchBlocks(any(), any(), any(), any(), any(),
any())
+ verifyFetchBlocksInvocationCount(2)
assert(numResults == 3)
}
test("release current unexhausted buffer in case the task completes early") {
- val blockManager = createMockBlockManager()
// Make sure remote blocks would return
val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2)
val blocks = Map[BlockId, ManagedBuffer](
@@ -549,40 +488,25 @@ class ShuffleBlockFetcherIteratorSuite extends
SparkFunSuite with PrivateMethodT
// Semaphore to coordinate event sequence in two different threads.
val sem = new Semaphore(0)
- val transfer = mock(classOf[BlockTransferService])
- when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any()))
- .thenAnswer((invocation: InvocationOnMock) => {
- val listener =
invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
- Future {
- // Return the first two blocks, and wait till task completion before
returning the 3rd one
- listener.onBlockFetchSuccess(
- ShuffleBlockId(0, 0, 0).toString, blocks(ShuffleBlockId(0, 0, 0)))
- listener.onBlockFetchSuccess(
- ShuffleBlockId(0, 1, 0).toString, blocks(ShuffleBlockId(0, 1, 0)))
- sem.acquire()
- listener.onBlockFetchSuccess(
- ShuffleBlockId(0, 2, 0).toString, blocks(ShuffleBlockId(0, 2, 0)))
- }
- })
-
- val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])](
- (remoteBmId, blocks.keys.map(blockId => (blockId, 1L,
0)).toSeq)).toIterator
+ answerFetchBlocks { invocation =>
+ val listener = invocation.getArgument[BlockFetchingListener](4)
+ Future {
+ // Return the first two blocks, and wait till task completion before
returning the 3rd one
+ listener.onBlockFetchSuccess(
+ ShuffleBlockId(0, 0, 0).toString, blocks(ShuffleBlockId(0, 0, 0)))
+ listener.onBlockFetchSuccess(
+ ShuffleBlockId(0, 1, 0).toString, blocks(ShuffleBlockId(0, 1, 0)))
+ sem.acquire()
+ listener.onBlockFetchSuccess(
+ ShuffleBlockId(0, 2, 0).toString, blocks(ShuffleBlockId(0, 2, 0)))
+ }
+ }
val taskContext = TaskContext.empty()
- val iterator = new ShuffleBlockFetcherIterator(
- taskContext,
- transfer,
- blockManager,
- blocksByAddress,
- (_, in) => in,
- 48 * 1024 * 1024,
- Int.MaxValue,
- Int.MaxValue,
- Int.MaxValue,
- true,
- false,
- taskContext.taskMetrics.createTempShuffleReadMetrics(),
- false)
+ val iterator = createShuffleBlockIteratorWithDefaults(
+ Map(remoteBmId -> toBlockList(blocks.keys, 1L, 0)),
+ taskContext = Some(taskContext)
+ )
verify(blocks(ShuffleBlockId(0, 0, 0)), times(0)).release()
iterator.next()._2.close() // close() first block's input stream
@@ -603,7 +527,6 @@ class ShuffleBlockFetcherIteratorSuite extends
SparkFunSuite with PrivateMethodT
}
test("fail all blocks if any of the remote request fails") {
- val blockManager = createMockBlockManager()
// Make sure remote blocks would return
val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2)
val blocks = Map[BlockId, ManagedBuffer](
@@ -615,41 +538,23 @@ class ShuffleBlockFetcherIteratorSuite extends
SparkFunSuite with PrivateMethodT
// Semaphore to coordinate event sequence in two different threads.
val sem = new Semaphore(0)
- val transfer = mock(classOf[BlockTransferService])
- when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any()))
- .thenAnswer((invocation: InvocationOnMock) => {
- val listener =
invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
- Future {
- // Return the first block, and then fail.
- listener.onBlockFetchSuccess(
- ShuffleBlockId(0, 0, 0).toString, blocks(ShuffleBlockId(0, 0, 0)))
- listener.onBlockFetchFailure(
- ShuffleBlockId(0, 1, 0).toString, new
BlockNotFoundException("blah"))
+ answerFetchBlocks { invocation =>
+ val listener = invocation.getArgument[BlockFetchingListener](4)
+ Future {
+ // Return the first block, and then fail.
+ listener.onBlockFetchSuccess(
+ ShuffleBlockId(0, 0, 0).toString, blocks(ShuffleBlockId(0, 0, 0)))
+ listener.onBlockFetchFailure(
+ ShuffleBlockId(0, 1, 0).toString, new BlockNotFoundException("blah"))
listener.onBlockFetchFailure(
ShuffleBlockId(0, 2, 0).toString, new
BlockNotFoundException("blah"))
- sem.release()
- }
- })
-
- val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])](
- (remoteBmId, blocks.keys.map(blockId => (blockId, 1L, 0)).toSeq))
- .toIterator
+ sem.release()
+ }
+ }
- val taskContext = TaskContext.empty()
- val iterator = new ShuffleBlockFetcherIterator(
- taskContext,
- transfer,
- blockManager,
- blocksByAddress,
- (_, in) => in,
- 48 * 1024 * 1024,
- Int.MaxValue,
- Int.MaxValue,
- Int.MaxValue,
- true,
- false,
- taskContext.taskMetrics.createTempShuffleReadMetrics(),
- false)
+ val iterator = createShuffleBlockIteratorWithDefaults(
+ Map(remoteBmId -> toBlockList(blocks.keys, 1L, 0))
+ )
// Continue only after the mock calls onBlockFetchFailure
sem.acquire()
@@ -690,7 +595,6 @@ class ShuffleBlockFetcherIteratorSuite extends
SparkFunSuite with PrivateMethodT
}
test("retry corrupt blocks") {
- val blockManager = createMockBlockManager()
// Make sure remote blocks would return
val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2)
val blocks = Map[BlockId, ManagedBuffer](
@@ -703,40 +607,24 @@ class ShuffleBlockFetcherIteratorSuite extends
SparkFunSuite with PrivateMethodT
val sem = new Semaphore(0)
val corruptLocalBuffer = new FileSegmentManagedBuffer(null, new File("a"),
0, 100)
- val transfer = mock(classOf[BlockTransferService])
- when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any()))
- .thenAnswer((invocation: InvocationOnMock) => {
- val listener =
invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
- Future {
- // Return the first block, and then fail.
- listener.onBlockFetchSuccess(
- ShuffleBlockId(0, 0, 0).toString, blocks(ShuffleBlockId(0, 0, 0)))
- listener.onBlockFetchSuccess(
- ShuffleBlockId(0, 1, 0).toString, mockCorruptBuffer())
- listener.onBlockFetchSuccess(
- ShuffleBlockId(0, 2, 0).toString, corruptLocalBuffer)
- sem.release()
- }
- })
-
- val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])](
- (remoteBmId, blocks.keys.map(blockId => (blockId, 1L,
0)).toSeq)).toIterator
+ answerFetchBlocks { invocation =>
+ val listener = invocation.getArgument[BlockFetchingListener](4)
+ Future {
+ // Return the first block, and then fail.
+ listener.onBlockFetchSuccess(
+ ShuffleBlockId(0, 0, 0).toString, blocks(ShuffleBlockId(0, 0, 0)))
+ listener.onBlockFetchSuccess(
+ ShuffleBlockId(0, 1, 0).toString, mockCorruptBuffer())
+ listener.onBlockFetchSuccess(
+ ShuffleBlockId(0, 2, 0).toString, corruptLocalBuffer)
+ sem.release()
+ }
+ }
- val taskContext = TaskContext.empty()
- val iterator = new ShuffleBlockFetcherIterator(
- taskContext,
- transfer,
- blockManager,
- blocksByAddress,
- (_, in) => new LimitedInputStream(in, 100),
- 48 * 1024 * 1024,
- Int.MaxValue,
- Int.MaxValue,
- Int.MaxValue,
- true,
- true,
- taskContext.taskMetrics.createTempShuffleReadMetrics(),
- false)
+ val iterator = createShuffleBlockIteratorWithDefaults(
+ Map(remoteBmId -> toBlockList(blocks.keys, 1L, 0)),
+ streamWrapperLimitSize = Some(100)
+ )
// Continue only after the mock calls onBlockFetchFailure
sem.acquire()
@@ -745,16 +633,14 @@ class ShuffleBlockFetcherIteratorSuite extends
SparkFunSuite with PrivateMethodT
val (id1, _) = iterator.next()
assert(id1 === ShuffleBlockId(0, 0, 0))
- when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any()))
- .thenAnswer((invocation: InvocationOnMock) => {
- val listener =
invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
- Future {
- // Return the first block, and then fail.
- listener.onBlockFetchSuccess(
- ShuffleBlockId(0, 1, 0).toString, mockCorruptBuffer())
- sem.release()
- }
- })
+ answerFetchBlocks { invocation =>
+ val listener = invocation.getArgument[BlockFetchingListener](4)
+ Future {
+ // Return the first block, and then fail.
+ listener.onBlockFetchSuccess(ShuffleBlockId(0, 1, 0).toString,
mockCorruptBuffer())
+ sem.release()
+ }
+ }
// The next block is corrupt local block (the second one is corrupt and
retried)
intercept[FetchFailedException] { iterator.next() }
@@ -765,47 +651,28 @@ class ShuffleBlockFetcherIteratorSuite extends
SparkFunSuite with PrivateMethodT
test("big blocks are also checked for corruption") {
val streamLength = 10000L
- val blockManager = createMockBlockManager()
// This stream will throw IOException when the first byte is read
val corruptBuffer1 = mockCorruptBuffer(streamLength, 0)
val blockManagerId1 = BlockManagerId("remote-client-1", "remote-client-1",
1)
val shuffleBlockId1 = ShuffleBlockId(0, 1, 0)
- val blockLengths1 = Seq[Tuple3[BlockId, Long, Int]](
- (shuffleBlockId1, corruptBuffer1.size(), 1)
- )
val streamNotCorruptTill = 8 * 1024
// This stream will throw exception after streamNotCorruptTill bytes are
read
val corruptBuffer2 = mockCorruptBuffer(streamLength, streamNotCorruptTill)
val blockManagerId2 = BlockManagerId("remote-client-2", "remote-client-2",
2)
val shuffleBlockId2 = ShuffleBlockId(0, 2, 0)
- val blockLengths2 = Seq[Tuple3[BlockId, Long, Int]](
- (shuffleBlockId2, corruptBuffer2.size(), 2)
- )
- val transfer = createMockTransfer(
+ configureMockTransfer(
Map(shuffleBlockId1 -> corruptBuffer1, shuffleBlockId2 ->
corruptBuffer2))
- val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])](
- (blockManagerId1, blockLengths1),
- (blockManagerId2, blockLengths2)
- ).toIterator
- val taskContext = TaskContext.empty()
- val maxBytesInFlight = 3 * 1024
- val iterator = new ShuffleBlockFetcherIterator(
- taskContext,
- transfer,
- blockManager,
- blocksByAddress,
- (_, in) => new LimitedInputStream(in, streamLength),
- maxBytesInFlight,
- Int.MaxValue,
- Int.MaxValue,
- Int.MaxValue,
- true,
- true,
- taskContext.taskMetrics.createTempShuffleReadMetrics(),
- false)
+ val iterator = createShuffleBlockIteratorWithDefaults(
+ Map(
+ blockManagerId1 -> toBlockList(Seq(shuffleBlockId1),
corruptBuffer1.size(), 1),
+ blockManagerId2 -> toBlockList(Seq(shuffleBlockId2),
corruptBuffer2.size(), 2)
+ ),
+ streamWrapperLimitSize = Some(streamLength),
+ maxBytesInFlight = 3 * 1024
+ )
// We'll get back the block which has corruption after maxBytesInFlight/3
because the other
// block will detect corruption on first fetch, and then get added to the
queue again for
@@ -848,30 +715,15 @@ class ShuffleBlockFetcherIteratorSuite extends
SparkFunSuite with PrivateMethodT
val localBmId = BlockManagerId("test-client", "test-client", 1)
doReturn(localBmId).when(blockManager).blockManagerId
doReturn(managedBuffer).when(blockManager).getLocalBlockData(meq(ShuffleBlockId(0,
0, 0)))
- val localBlockLengths = Seq[Tuple3[BlockId, Long, Int]](
- (ShuffleBlockId(0, 0, 0), 10000, 0)
- )
- val transfer = createMockTransfer(Map(ShuffleBlockId(0, 0, 0) ->
managedBuffer))
- val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])](
- (localBmId, localBlockLengths)
- ).toIterator
+ configureMockTransfer(Map(ShuffleBlockId(0, 0, 0) -> managedBuffer))
- val taskContext = TaskContext.empty()
- val iterator = new ShuffleBlockFetcherIterator(
- taskContext,
- transfer,
- blockManager,
- blocksByAddress,
- (_, in) => new LimitedInputStream(in, 10000),
- 2048,
- Int.MaxValue,
- Int.MaxValue,
- Int.MaxValue,
- true,
- true,
- taskContext.taskMetrics.createTempShuffleReadMetrics(),
- false)
- val (id, st) = iterator.next()
+ val iterator = createShuffleBlockIteratorWithDefaults(
+ Map(localBmId -> toBlockList(Seq(ShuffleBlockId(0, 0, 0)), 10000L, 0)),
+ blockManager = Some(blockManager),
+ streamWrapperLimitSize = Some(10000),
+ maxBytesInFlight = 2048 // force concatenation of stream by limiting
bytes in flight
+ )
+ val (_, st) = iterator.next()
// Check that the test setup is correct -- make sure we have a
concatenated stream.
assert
(st.asInstanceOf[BufferReleasingInputStream].delegate.isInstanceOf[SequenceInputStream])
@@ -884,7 +736,6 @@ class ShuffleBlockFetcherIteratorSuite extends
SparkFunSuite with PrivateMethodT
}
test("retry corrupt blocks (disabled)") {
- val blockManager = createMockBlockManager()
// Make sure remote blocks would return
val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2)
val blocks = Map[BlockId, ManagedBuffer](
@@ -896,41 +747,25 @@ class ShuffleBlockFetcherIteratorSuite extends
SparkFunSuite with PrivateMethodT
// Semaphore to coordinate event sequence in two different threads.
val sem = new Semaphore(0)
- val transfer = mock(classOf[BlockTransferService])
- when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any()))
- .thenAnswer((invocation: InvocationOnMock) => {
- val listener =
invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
- Future {
- // Return the first block, and then fail.
- listener.onBlockFetchSuccess(
+ answerFetchBlocks { invocation =>
+ val listener = invocation.getArgument[BlockFetchingListener](4)
+ Future {
+ // Return the first block, and then fail.
+ listener.onBlockFetchSuccess(
ShuffleBlockId(0, 0, 0).toString, blocks(ShuffleBlockId(0, 0, 0)))
- listener.onBlockFetchSuccess(
- ShuffleBlockId(0, 1, 0).toString, mockCorruptBuffer())
- listener.onBlockFetchSuccess(
- ShuffleBlockId(0, 2, 0).toString, mockCorruptBuffer())
- sem.release()
- }
- })
-
- val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])](
- (remoteBmId, blocks.keys.map(blockId => (blockId, 1L, 0)).toSeq))
- .toIterator
+ listener.onBlockFetchSuccess(
+ ShuffleBlockId(0, 1, 0).toString, mockCorruptBuffer())
+ listener.onBlockFetchSuccess(
+ ShuffleBlockId(0, 2, 0).toString, mockCorruptBuffer())
+ sem.release()
+ }
+ }
- val taskContext = TaskContext.empty()
- val iterator = new ShuffleBlockFetcherIterator(
- taskContext,
- transfer,
- blockManager,
- blocksByAddress,
- (_, in) => new LimitedInputStream(in, 100),
- 48 * 1024 * 1024,
- Int.MaxValue,
- Int.MaxValue,
- Int.MaxValue,
- true,
- false,
- taskContext.taskMetrics.createTempShuffleReadMetrics(),
- false)
+ val iterator = createShuffleBlockIteratorWithDefaults(
+ Map(remoteBmId -> toBlockList(blocks.keys, 1L, 0)),
+ streamWrapperLimitSize = Some(100),
+ detectCorruptUseExtraMemory = false
+ )
// Continue only after the mock calls onBlockFetchFailure
sem.acquire()
@@ -958,57 +793,38 @@ class ShuffleBlockFetcherIteratorSuite extends
SparkFunSuite with PrivateMethodT
val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2)
val remoteBlocks = Map[BlockId, ManagedBuffer](
ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer())
- val transfer = mock(classOf[BlockTransferService])
var tempFileManager: DownloadFileManager = null
- when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any()))
- .thenAnswer((invocation: InvocationOnMock) => {
- val listener =
invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
- tempFileManager =
invocation.getArguments()(5).asInstanceOf[DownloadFileManager]
- Future {
- listener.onBlockFetchSuccess(
- ShuffleBlockId(0, 0, 0).toString, remoteBlocks(ShuffleBlockId(0,
0, 0)))
- }
- })
+ answerFetchBlocks { invocation =>
+ val listener = invocation.getArgument[BlockFetchingListener](4)
+ tempFileManager = invocation.getArgument[DownloadFileManager](5)
+ Future {
+ listener.onBlockFetchSuccess(
+ ShuffleBlockId(0, 0, 0).toString, remoteBlocks(ShuffleBlockId(0, 0,
0)))
+ }
+ }
- def fetchShuffleBlock(
- blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long,
Int)])]): Unit = {
- // Set `maxBytesInFlight` and `maxReqsInFlight` to `Int.MaxValue`, so
that during the
+ def fetchShuffleBlock(blockSize: Long): Unit = {
+ // Use default `maxBytesInFlight` and `maxReqsInFlight` (`Int.MaxValue`)
so that during the
// construction of `ShuffleBlockFetcherIterator`, all requests to fetch
remote shuffle blocks
// are issued. The `maxReqSizeShuffleToMem` is hard-coded as 200 here.
- val taskContext = TaskContext.empty()
- new ShuffleBlockFetcherIterator(
- taskContext,
- transfer,
- blockManager,
- blocksByAddress,
- (_, in) => in,
- maxBytesInFlight = Int.MaxValue,
- maxReqsInFlight = Int.MaxValue,
- maxBlocksInFlightPerAddress = Int.MaxValue,
- maxReqSizeShuffleToMem = 200,
- detectCorrupt = true,
- false,
- taskContext.taskMetrics.createTempShuffleReadMetrics(),
- false)
+ createShuffleBlockIteratorWithDefaults(
+ Map(remoteBmId -> toBlockList(remoteBlocks.keys, blockSize, 0)),
+ blockManager = Some(blockManager),
+ maxReqSizeShuffleToMem = 200)
}
- val blocksByAddress1 = Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])](
- (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 100L,
0)).toSeq)).toIterator
- fetchShuffleBlock(blocksByAddress1)
+ fetchShuffleBlock(100L)
// `maxReqSizeShuffleToMem` is 200, which is greater than the block size
100, so don't fetch
// shuffle block to disk.
assert(tempFileManager == null)
- val blocksByAddress2 = Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])](
- (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 300L,
0)).toSeq)).toIterator
- fetchShuffleBlock(blocksByAddress2)
+ fetchShuffleBlock(300L)
// `maxReqSizeShuffleToMem` is 200, which is smaller than the block size
300, so fetch
// shuffle block to disk.
assert(tempFileManager != null)
}
test("fail zero-size blocks") {
- val blockManager = createMockBlockManager()
// Make sure remote blocks would return
val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2)
val blocks = Map[BlockId, ManagedBuffer](
@@ -1016,26 +832,11 @@ class ShuffleBlockFetcherIteratorSuite extends
SparkFunSuite with PrivateMethodT
ShuffleBlockId(0, 1, 0) -> createMockManagedBuffer()
)
- val transfer = createMockTransfer(blocks.mapValues(_ =>
createMockManagedBuffer(0)).toMap)
-
- val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])](
- (remoteBmId, blocks.keys.map(blockId => (blockId, 1L, 0)).toSeq))
+ configureMockTransfer(blocks.mapValues(_ =>
createMockManagedBuffer(0)).toMap)
- val taskContext = TaskContext.empty()
- val iterator = new ShuffleBlockFetcherIterator(
- taskContext,
- transfer,
- blockManager,
- blocksByAddress.toIterator,
- (_, in) => in,
- 48 * 1024 * 1024,
- Int.MaxValue,
- Int.MaxValue,
- Int.MaxValue,
- true,
- false,
- taskContext.taskMetrics.createTempShuffleReadMetrics(),
- false)
+ val iterator = createShuffleBlockIteratorWithDefaults(
+ Map(remoteBmId -> toBlockList(blocks.keys, 1L, 0))
+ )
// All blocks fetched return zero length and should trigger a receive-side
error:
val e = intercept[FetchFailedException] { iterator.next() }
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]