Repository: spark
Updated Branches:
  refs/heads/master 939a322c8 -> 08ce18881


http://git-wip-us.apache.org/repos/asf/spark/blob/08ce1888/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
----------------------------------------------------------------------
diff --git 
a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala 
b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
index c200654..e251660 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
@@ -21,15 +21,19 @@ import java.nio.{ByteBuffer, MappedByteBuffer}
 import java.util.Arrays
 import java.util.concurrent.TimeUnit
 
+import org.apache.spark.network.nio.NioBlockTransferService
+
+import scala.collection.mutable.ArrayBuffer
+import scala.concurrent.Await
+import scala.concurrent.duration._
+import scala.language.implicitConversions
+import scala.language.postfixOps
+
 import akka.actor._
 import akka.pattern.ask
 import akka.util.Timeout
-import org.apache.spark.shuffle.hash.HashShuffleManager
 
-import org.mockito.invocation.InvocationOnMock
-import org.mockito.Matchers.any
-import org.mockito.Mockito.{doAnswer, mock, spy, when}
-import org.mockito.stubbing.Answer
+import org.mockito.Mockito.{mock, when}
 
 import org.scalatest.{BeforeAndAfter, FunSuite, PrivateMethodTester}
 import org.scalatest.concurrent.Eventually._
@@ -38,18 +42,12 @@ import org.scalatest.Matchers
 
 import org.apache.spark.{MapOutputTrackerMaster, SecurityManager, SparkConf}
 import org.apache.spark.executor.DataReadMethod
-import org.apache.spark.network.{Message, ConnectionManagerId}
 import org.apache.spark.scheduler.LiveListenerBus
 import org.apache.spark.serializer.{JavaSerializer, KryoSerializer}
+import org.apache.spark.shuffle.hash.HashShuffleManager
 import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat
 import org.apache.spark.util.{AkkaUtils, ByteBufferInputStream, SizeEstimator, 
Utils}
 
-import scala.collection.mutable.ArrayBuffer
-import scala.concurrent.Await
-import scala.concurrent.duration._
-import scala.language.implicitConversions
-import scala.language.postfixOps
-import org.apache.spark.shuffle.ShuffleBlockManager
 
 class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter
   with PrivateMethodTester {
@@ -74,8 +72,9 @@ class BlockManagerSuite extends FunSuite with Matchers with 
BeforeAndAfter
   def rdd(rddId: Int, splitId: Int) = RDDBlockId(rddId, splitId)
 
   private def makeBlockManager(maxMem: Long, name: String = "<driver>"): 
BlockManager = {
-    new BlockManager(name, actorSystem, master, serializer, maxMem, conf, 
securityMgr,
-      mapOutputTracker, shuffleManager)
+    val transfer = new NioBlockTransferService(conf, securityMgr)
+    new BlockManager(name, actorSystem, master, serializer, maxMem, conf,
+      mapOutputTracker, shuffleManager, transfer)
   }
 
   before {
@@ -793,8 +792,9 @@ class BlockManagerSuite extends FunSuite with Matchers with 
BeforeAndAfter
 
   test("block store put failure") {
     // Use Java serializer so we can create an unserializable error.
+    val transfer = new NioBlockTransferService(conf, securityMgr)
     store = new BlockManager("<driver>", actorSystem, master, new 
JavaSerializer(conf), 1200, conf,
-      securityMgr, mapOutputTracker, shuffleManager)
+      mapOutputTracker, shuffleManager, transfer)
 
     // The put should fail since a1 is not serializable.
     class UnserializableClass
@@ -1005,109 +1005,6 @@ class BlockManagerSuite extends FunSuite with Matchers 
with BeforeAndAfter
     assert(!store.memoryStore.contains(rdd(1, 0)), "rdd_1_0 was in store")
   }
 
-  test("return error message when error occurred in 
BlockManagerWorker#onBlockMessageReceive") {
-    store = new BlockManager("<driver>", actorSystem, master, serializer, 
1200, conf,
-      securityMgr, mapOutputTracker, shuffleManager)
-
-    val worker = spy(new BlockManagerWorker(store))
-    val connManagerId = mock(classOf[ConnectionManagerId])
-
-    // setup request block messages
-    val reqBlId1 = ShuffleBlockId(0,0,0)
-    val reqBlId2 = ShuffleBlockId(0,1,0)
-    val reqBlockMessage1 = BlockMessage.fromGetBlock(GetBlock(reqBlId1))
-    val reqBlockMessage2 = BlockMessage.fromGetBlock(GetBlock(reqBlId2))
-    val reqBlockMessages = new BlockMessageArray(
-      Seq(reqBlockMessage1, reqBlockMessage2))
-    val reqBufferMessage = reqBlockMessages.toBufferMessage
-
-    val answer = new Answer[Option[BlockMessage]] {
-      override def answer(invocation: InvocationOnMock)
-          :Option[BlockMessage]= {
-        throw new Exception
-      }
-    }
-
-    doAnswer(answer).when(worker).processBlockMessage(any())
-
-    // Test when exception was thrown during processing block messages
-    var ackMessage = worker.onBlockMessageReceive(reqBufferMessage, 
connManagerId)
-    
-    assert(ackMessage.isDefined, "When Exception was thrown in " +
-      "BlockManagerWorker#processBlockMessage, " +
-      "ackMessage should be defined")
-    assert(ackMessage.get.hasError, "When Exception was thown in " +
-      "BlockManagerWorker#processBlockMessage, " +
-      "ackMessage should have error")
-
-    val notBufferMessage = mock(classOf[Message])
-
-    // Test when not BufferMessage was received
-    ackMessage = worker.onBlockMessageReceive(notBufferMessage, connManagerId)
-    assert(ackMessage.isDefined, "When not BufferMessage was passed to " +
-      "BlockManagerWorker#onBlockMessageReceive, " +
-      "ackMessage should be defined")
-    assert(ackMessage.get.hasError, "When not BufferMessage was passed to " +
-      "BlockManagerWorker#onBlockMessageReceive, " +
-      "ackMessage should have error")
-  }
-
-  test("return ack message when no error occurred in 
BlocManagerWorker#onBlockMessageReceive") {
-    store = new BlockManager("<driver>", actorSystem, master, serializer, 
1200, conf,
-      securityMgr, mapOutputTracker, shuffleManager)
-
-    val worker = spy(new BlockManagerWorker(store))
-    val connManagerId = mock(classOf[ConnectionManagerId])
-
-    // setup request block messages
-    val reqBlId1 = ShuffleBlockId(0,0,0)
-    val reqBlId2 = ShuffleBlockId(0,1,0)
-    val reqBlockMessage1 = BlockMessage.fromGetBlock(GetBlock(reqBlId1))
-    val reqBlockMessage2 = BlockMessage.fromGetBlock(GetBlock(reqBlId2))
-    val reqBlockMessages = new BlockMessageArray(
-      Seq(reqBlockMessage1, reqBlockMessage2))
-
-    val tmpBufferMessage = reqBlockMessages.toBufferMessage
-    val buffer = ByteBuffer.allocate(tmpBufferMessage.size)
-    val arrayBuffer = new ArrayBuffer[ByteBuffer]
-    tmpBufferMessage.buffers.foreach{ b =>
-      buffer.put(b)
-    }
-    buffer.flip()
-    arrayBuffer += buffer
-    val reqBufferMessage = Message.createBufferMessage(arrayBuffer)
-
-    // setup ack block messages
-    val buf1 = ByteBuffer.allocate(4)
-    val buf2 = ByteBuffer.allocate(4)
-    buf1.putInt(1)
-    buf1.flip()
-    buf2.putInt(1)
-    buf2.flip()
-    val ackBlockMessage1 = BlockMessage.fromGotBlock(GotBlock(reqBlId1, buf1))
-    val ackBlockMessage2 = BlockMessage.fromGotBlock(GotBlock(reqBlId2, buf2))
-
-    val answer = new Answer[Option[BlockMessage]] {
-      override def answer(invocation: InvocationOnMock)
-          :Option[BlockMessage]= {
-        if (invocation.getArguments()(0).asInstanceOf[BlockMessage].eq(
-          reqBlockMessage1)) {
-          return Some(ackBlockMessage1)
-        } else {
-          return Some(ackBlockMessage2)
-        }
-      }
-    }
-
-    doAnswer(answer).when(worker).processBlockMessage(any())
-
-    val ackMessage = worker.onBlockMessageReceive(reqBufferMessage, 
connManagerId)
-    assert(ackMessage.isDefined, "When 
BlockManagerWorker#onBlockMessageReceive " +
-      "was executed successfully, ackMessage should be defined")
-    assert(!ackMessage.get.hasError, "When 
BlockManagerWorker#onBlockMessageReceive " +
-      "was executed successfully, ackMessage should not have error")
-  }
-
   test("reserve/release unroll memory") {
     store = makeBlockManager(12000)
     val memoryStore = store.memoryStore

http://git-wip-us.apache.org/repos/asf/spark/blob/08ce1888/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala
----------------------------------------------------------------------
diff --git 
a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala 
b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala
index 26082de..e4522e0 100644
--- a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.storage
 
 import java.io.{File, FileWriter}
 
+import org.apache.spark.network.nio.NioBlockTransferService
 import org.apache.spark.shuffle.hash.HashShuffleManager
 
 import scala.collection.mutable
@@ -52,7 +53,6 @@ class DiskBlockManagerSuite extends FunSuite with 
BeforeAndAfterEach with Before
     rootDir1 = Files.createTempDir()
     rootDir1.deleteOnExit()
     rootDirs = rootDir0.getAbsolutePath + "," + rootDir1.getAbsolutePath
-    println("Created root dirs: " + rootDirs)
   }
 
   override def afterAll() {

http://git-wip-us.apache.org/repos/asf/spark/blob/08ce1888/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
----------------------------------------------------------------------
diff --git 
a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
 
b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
new file mode 100644
index 0000000..809bd70
--- /dev/null
+++ 
b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
@@ -0,0 +1,183 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import org.apache.spark.TaskContext
+import org.apache.spark.network.{BlockFetchingListener, BlockTransferService}
+
+import org.mockito.Mockito._
+import org.mockito.Matchers.{any, eq => meq}
+import org.mockito.invocation.InvocationOnMock
+import org.mockito.stubbing.Answer
+
+import org.scalatest.FunSuite
+
+
+class ShuffleBlockFetcherIteratorSuite extends FunSuite {
+
+  test("handle local read failures in BlockManager") {
+    val transfer = mock(classOf[BlockTransferService])
+    val blockManager = mock(classOf[BlockManager])
+    doReturn(BlockManagerId("test-client", "test-client", 1)).when(blockManager).blockManagerId
+
+    val blIds = Array[BlockId](
+      ShuffleBlockId(0,0,0),
+      ShuffleBlockId(0,1,0),
+      ShuffleBlockId(0,2,0),
+      ShuffleBlockId(0,3,0),
+      ShuffleBlockId(0,4,0))
+
+    val optItr = mock(classOf[Option[Iterator[Any]]])
+    // Answer that always throws; used below to make the read of the 3rd block fail.
+    val answer = new Answer[Option[Iterator[Any]]] {
+      override def answer(invocation: InvocationOnMock): Option[Iterator[Any]] =
+        throw new Exception
+    }
+
+    // 3rd block is going to fail
+    doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(0)), any())
+    doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(1)), any())
+    doAnswer(answer).when(blockManager).getLocalShuffleFromDisk(meq(blIds(2)), any())
+    doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(3)), any())
+    doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(4)), any())
+
+    val bmId = BlockManagerId("test-client", "test-client", 1)
+    val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
+      (bmId, blIds.map(blId => (blId, 1L)).toSeq)
+    )
+
+    val iterator = new ShuffleBlockFetcherIterator(
+      new TaskContext(0, 0, 0),
+      transfer,
+      blockManager,
+      blocksByAddress,
+      null,
+      48 * 1024 * 1024)
+
+    // Without exhausting the iterator, the iterator should be lazy and not call
+    // getLocalShuffleFromDisk.
+    verify(blockManager, times(0)).getLocalShuffleFromDisk(any(), any())
+
+    assert(iterator.hasNext, "iterator should have 5 elements but actually has no elements")
+    // the 2nd element of the tuple returned by iterator.next should be defined when
+    // the fetch succeeds
+    assert(iterator.next()._2.isDefined,
+      "1st element should be defined but is not actually defined")
+    verify(blockManager, times(1)).getLocalShuffleFromDisk(any(), any())
+
+    assert(iterator.hasNext, "iterator should have 5 elements but actually has 1 element")
+    assert(iterator.next()._2.isDefined,
+      "2nd element should be defined but is not actually defined")
+    verify(blockManager, times(2)).getLocalShuffleFromDisk(any(), any())
+
+    assert(iterator.hasNext, "iterator should have 5 elements but actually has 2 elements")
+    // 3rd fetch should fail
+    intercept[Exception] {
+      iterator.next()
+    }
+    verify(blockManager, times(3)).getLocalShuffleFromDisk(any(), any())
+  }
+
+  test("handle local read successes") {
+    val transfer = mock(classOf[BlockTransferService])
+    val blockManager = mock(classOf[BlockManager])
+    doReturn(BlockManagerId("test-client", "test-client", 1)).when(blockManager).blockManagerId
+
+    val blIds = Array[BlockId](
+      ShuffleBlockId(0,0,0),
+      ShuffleBlockId(0,1,0),
+      ShuffleBlockId(0,2,0),
+      ShuffleBlockId(0,3,0),
+      ShuffleBlockId(0,4,0))
+
+    val optItr = mock(classOf[Option[Iterator[Any]]])
+
+    // All blocks should be fetched successfully
+    blIds.foreach { blId =>
+      doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blId), any())
+    }
+
+    val bmId = BlockManagerId("test-client", "test-client", 1)
+    val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
+      (bmId, blIds.map(blId => (blId, 1L)).toSeq)
+    )
+
+    val iterator = new ShuffleBlockFetcherIterator(
+      new TaskContext(0, 0, 0),
+      transfer,
+      blockManager,
+      blocksByAddress,
+      null,
+      48 * 1024 * 1024)
+
+    // Without exhausting the iterator, the iterator should be lazy and not call
+    // getLocalShuffleFromDisk.
+    verify(blockManager, times(0)).getLocalShuffleFromDisk(any(), any())
+
+    // Each of the 5 stubbed blocks is expected to come back as a defined element.
+    assert(iterator.hasNext, "iterator should have 5 elements but actually has no elements")
+    assert(iterator.next()._2.isDefined,
+      "All elements should be defined but 1st element is not actually defined")
+    assert(iterator.hasNext, "iterator should have 5 elements but actually has 1 element")
+    assert(iterator.next()._2.isDefined,
+      "All elements should be defined but 2nd element is not actually defined")
+    assert(iterator.hasNext, "iterator should have 5 elements but actually has 2 elements")
+    assert(iterator.next()._2.isDefined,
+      "All elements should be defined but 3rd element is not actually defined")
+    assert(iterator.hasNext, "iterator should have 5 elements but actually has 3 elements")
+    assert(iterator.next()._2.isDefined,
+      "All elements should be defined but 4th element is not actually defined")
+    assert(iterator.hasNext, "iterator should have 5 elements but actually has 4 elements")
+    assert(iterator.next()._2.isDefined,
+      "All elements should be defined but 5th element is not actually defined")
+
+    verify(blockManager, times(5)).getLocalShuffleFromDisk(any(), any())
+  }
+
+  test("handle remote fetch failures in BlockTransferService") {
+    // Every fetch request is answered by reporting a failure to the supplied listener.
+    val failingFetch = new Answer[Unit] {
+      override def answer(invocation: InvocationOnMock): Unit = {
+        val callback = invocation.getArguments()(3).asInstanceOf[BlockFetchingListener]
+        callback.onBlockFetchFailure(new Exception("blah"))
+      }
+    }
+    val transfer = mock(classOf[BlockTransferService])
+    when(transfer.fetchBlocks(any(), any(), any(), any())).thenAnswer(failingFetch)
+
+    val blockManager = mock(classOf[BlockManager])
+    when(blockManager.blockManagerId).thenReturn(BlockManagerId("test-client", "test-client", 1))
+
+    val remoteId = BlockManagerId("test-server", "test-server", 1)
+    val remoteBlocks = Seq[(BlockId, Long)](
+      (ShuffleBlockId(0, 0, 0), 1L), (ShuffleBlockId(0, 1, 0), 1L))
+    val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]((remoteId, remoteBlocks))
+
+    val iterator = new ShuffleBlockFetcherIterator(
+      new TaskContext(0, 0, 0),
+      transfer,
+      blockManager,
+      blocksByAddress,
+      null,
+      48 * 1024 * 1024)
+
+    iterator.foreach { fetched =>
+      assert(!fetched._2.isDefined)
+    }
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to