http://git-wip-us.apache.org/repos/asf/spark/blob/08ce1888/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala
new file mode 100644
index 0000000..09d3ea3
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala
@@ -0,0 +1,1042 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.nio
+
+import java.io.IOException
+import java.net._
+import java.nio._
+import java.nio.channels._
+import java.nio.channels.spi._
+import java.util.concurrent.atomic.AtomicInteger
+import java.util.concurrent.{LinkedBlockingDeque, ThreadPoolExecutor, TimeUnit}
+import java.util.{Timer, TimerTask}
+
+import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, 
SynchronizedMap, SynchronizedQueue}
+import scala.concurrent.duration._
+import scala.concurrent.{Await, ExecutionContext, Future, Promise}
+import scala.language.postfixOps
+
+import org.apache.spark._
+import org.apache.spark.util.{SystemClock, Utils}
+
+
+private[nio] class ConnectionManager(
+    port: Int,
+    conf: SparkConf,
+    securityManager: SecurityManager,
+    name: String = "Connection manager")
+  extends Logging {
+
+  /**
+   * Bookkeeping record that sendMessageReliably keeps for a message being sent.
+   * @param message the message that was sent
+   * @param connectionManagerId the connection manager id this send is associated with; it is
+   *                            compared against the remote id of a connection being removed
+   * @param completionHandler callback invoked once the send has completed or failed
+   */
+  class MessageStatus(
+      val message: Message,
+      val connectionManagerId: ConnectionManagerId,
+      completionHandler: MessageStatus => Unit) {
+
+    /** Ack recorded by markDone; None until then (and also when markDone is given None). */
+    var ackMessage: Option[Message] = None
+
+    /**
+     * Stores the outcome and then hands this status to the completion handler.
+     * @param ackMessage the ack that arrived, or None when there is none to report
+     */
+    def markDone(ackMessage: Option[Message]): Unit = {
+      this.ackMessage = ackMessage
+      completionHandler(this)
+    }
+  }
+
+  // Selector shared by the listening server channel and the connections created by this manager.
+  private val selector = SelectorProvider.provider.openSelector()
+  // Daemon timer thread (the `true` argument marks the timer thread as a daemon).
+  private val ackTimeoutMonitor = new Timer("AckTimeoutMonitor", true)
+
+  // Timeout, in seconds, for waiting on authentication; defaults to 30.
+  private val authTimeout = conf.getInt("spark.core.connection.auth.wait.timeout", 30)
+  // Ack wait timeout; defaults to 60 (presumably seconds as well -- TODO confirm at its use site).
+  private val ackTimeout = conf.getInt("spark.core.connection.ack.wait.timeout", 60)
+
+  // NOTE(review): the three pools below all use an unbounded LinkedBlockingDeque. According to
+  // the ThreadPoolExecutor documentation a pool with an unbounded queue never grows beyond its
+  // core size, so the "*.threads.max" settings look ineffective -- confirm whether intended.
+
+  // Runs the message handlers submitted by receiveMessage.
+  private val handleMessageExecutor = new ThreadPoolExecutor(
+    conf.getInt("spark.core.connection.handler.threads.min", 20),
+    conf.getInt("spark.core.connection.handler.threads.max", 60),
+    conf.getInt("spark.core.connection.handler.threads.keepalive", 60), TimeUnit.SECONDS,
+    new LinkedBlockingDeque[Runnable](),
+    Utils.namedThreadFactory("handle-message-executor"))
+
+  // Runs the per-connection read and write tasks started by triggerRead / triggerWrite.
+  private val handleReadWriteExecutor = new ThreadPoolExecutor(
+    conf.getInt("spark.core.connection.io.threads.min", 4),
+    conf.getInt("spark.core.connection.io.threads.max", 32),
+    conf.getInt("spark.core.connection.io.threads.keepalive", 60), TimeUnit.SECONDS,
+    new LinkedBlockingDeque[Runnable](),
+    Utils.namedThreadFactory("handle-read-write-executor"))
+
+  // Use a different, yet smaller, thread pool - infrequently used with very short lived tasks,
+  // which should be executed asap
+  private val handleConnectExecutor = new ThreadPoolExecutor(
+    conf.getInt("spark.core.connection.connect.threads.min", 1),
+    conf.getInt("spark.core.connection.connect.threads.max", 8),
+    conf.getInt("spark.core.connection.connect.threads.keepalive", 60), TimeUnit.SECONDS,
+    new LinkedBlockingDeque[Runnable](),
+    Utils.namedThreadFactory("handle-connect-executor"))
+
+  private val serverChannel = ServerSocketChannel.open()
+  // used to track the SendingConnections waiting to do SASL negotiation
+  private val connectionsAwaitingSasl = new HashMap[ConnectionId, SendingConnection]
+    with SynchronizedMap[ConnectionId, SendingConnection]
+  // Every tracked connection (sending or receiving), looked up by its selection key.
+  private val connectionsByKey =
+    new HashMap[SelectionKey, Connection] with SynchronizedMap[SelectionKey, Connection]
+  // The SendingConnection used to reach each remote connection manager.
+  private val connectionsById = new HashMap[ConnectionManagerId, SendingConnection]
+    with SynchronizedMap[ConnectionManagerId, SendingConnection]
+  // Plain HashMap (no SynchronizedMap mixin); users lock via messageStatuses.synchronized.
+  private val messageStatuses = new HashMap[Int, MessageStatus]
+  // Work queued for the selector thread; both queues are drained at the top of each run() pass.
+  private val keyInterestChangeRequests = new SynchronizedQueue[(SelectionKey, Int)]
+  private val registerRequests = new SynchronizedQueue[SendingConnection]
+
+  implicit val futureExecContext = ExecutionContext.fromExecutor(
+    Utils.newDaemonCachedThreadPool("Connection manager future execution context"))
+
+  // Receive callback consulted by handleMessage; stays null until one is installed.
+  @volatile
+  private var onReceiveCallback: (BufferMessage, ConnectionManagerId) => Option[Message] = null
+
+  private val authEnabled = securityManager.isAuthenticationEnabled()
+
+  // Non-blocking so the channel can be registered with the selector.
+  serverChannel.configureBlocking(false)
+  serverChannel.socket.setReuseAddress(true)
+  serverChannel.socket.setReceiveBufferSize(256 * 1024)
+
+  /**
+   * Binds the server channel's socket to `port`.
+   * @return the server channel paired with the local port the socket reports after binding
+   */
+  private def startService(port: Int): (ServerSocketChannel, Int) = {
+    val bindAddress = new InetSocketAddress(port)
+    serverChannel.socket.bind(bindAddress)
+    val boundPort = serverChannel.socket.getLocalPort
+    (serverChannel, boundPort)
+  }
+  // Bind the server channel through Utils.startServiceOnPort (which is handed startService),
+  // then register it with the selector for incoming connections.
+  Utils.startServiceOnPort[ServerSocketChannel](port, startService, name)
+  serverChannel.register(selector, SelectionKey.OP_ACCEPT)
+
+  val id = new ConnectionManagerId(Utils.localHostName, serverChannel.socket.getLocalPort)
+  logInfo("Bound socket to port " + serverChannel.socket.getLocalPort() + " with id = " + id)
+
+  // used in combination with the ConnectionManagerId to create unique Connection ids
+  // to be able to track asynchronous messages
+  private val idCount: AtomicInteger = new AtomicInteger(1)
+
+  // NOTE(review): the selector thread is started here, while the constructor is still running.
+  // Fields declared further down in the class (writeRunnableStarted, readRunnableStarted) are
+  // not assigned yet at this point although the thread's dispatch path uses them -- confirm
+  // that no event can be dispatched before construction finishes.
+  private val selectorThread = new Thread("connection-manager-thread") {
+    override def run() = ConnectionManager.this.run()
+  }
+  selectorThread.setDaemon(true)
+  selectorThread.start()
+
+  // Keys that currently have a write task queued or running; guarded by its own monitor.
+  private val writeRunnableStarted: HashSet[SelectionKey] = new HashSet[SelectionKey]()
+
+  /**
+   * Submits a write task for the connection registered under `key` to the read/write pool,
+   * allowing at most one such task per key at a time. Does nothing when no connection is
+   * tracked for the key.
+   */
+  private def triggerWrite(key: SelectionKey) {
+    val conn = connectionsByKey.getOrElse(key, null)
+    if (conn == null) return
+
+    writeRunnableStarted.synchronized {
+      // So that we do not trigger more write events while processing this one.
+      // The write method will re-register when done.
+      if (conn.changeInterestForWrite()) conn.unregisterInterest()
+      // A write task for this key is already in flight; it will pick up the pending data.
+      if (writeRunnableStarted.contains(key)) {
+        return
+      }
+
+      writeRunnableStarted += key
+    }
+    handleReadWriteExecutor.execute(new Runnable {
+      override def run() {
+        var register: Boolean = false
+        try {
+          register = conn.write()
+        } finally {
+          // Runs even if write() throws, so the key never stays marked as in flight.
+          writeRunnableStarted.synchronized {
+            writeRunnableStarted -= key
+            val needReregister = register || conn.resetForceReregister()
+            if (needReregister && conn.changeInterestForWrite()) {
+              conn.registerInterest()
+            }
+          }
+        }
+      }
+    } )
+  }
+
+  // Keys that currently have a read task queued or running; guarded by its own monitor.
+  private val readRunnableStarted: HashSet[SelectionKey] = new HashSet[SelectionKey]()
+
+  /**
+   * Submits a read task for the connection registered under `key` to the read/write pool,
+   * allowing at most one such task per key at a time. Does nothing when no connection is
+   * tracked for the key.
+   */
+  private def triggerRead(key: SelectionKey) {
+    val conn = connectionsByKey.getOrElse(key, null)
+    if (conn == null) return
+
+    readRunnableStarted.synchronized {
+      // So that we do not trigger more read events while processing this one.
+      // The read method will re-register when done.
+      if (conn.changeInterestForRead()) conn.unregisterInterest()
+      // A read task for this key is already in flight.
+      if (readRunnableStarted.contains(key)) {
+        return
+      }
+
+      readRunnableStarted += key
+    }
+    handleReadWriteExecutor.execute(new Runnable {
+      override def run() {
+        var register: Boolean = false
+        try {
+          register = conn.read()
+        } finally {
+          // Runs even if read() throws, so the key never stays marked as in flight.
+          readRunnableStarted.synchronized {
+            readRunnableStarted -= key
+            if (register && conn.changeInterestForRead()) {
+              conn.registerInterest()
+            }
+          }
+        }
+      }
+    } )
+  }
+
+  /**
+   * Completes the pending connect of the SendingConnection registered under `key` on the
+   * connect pool. Does nothing when no connection is tracked for the key.
+   */
+  private def triggerConnect(key: SelectionKey) {
+    val conn = connectionsByKey.getOrElse(key, null).asInstanceOf[SendingConnection]
+    if (conn == null) return
+
+    // prevent other events from being triggered
+    // Since we are still trying to connect, we do not need to do the additional steps in
+    // triggerWrite
+    conn.changeConnectionKeyInterest(0)
+
+    handleConnectExecutor.execute(new Runnable {
+      override def run() {
+
+        // Up to 11 attempts with finishConnect(false), 1 ms apart.
+        var tries: Int = 10
+        while (tries >= 0) {
+          if (conn.finishConnect(false)) return
+          // Sleep ?
+          Thread.sleep(1)
+          tries -= 1
+        }
+
+        // fallback to previous behavior : we should not really come here since this method was
+        // triggered since channel became connectable : but at times, the first finishConnect need
+        // not succeed : hence the loop to retry a few 'times'.
+        conn.finishConnect(true)
+      }
+    } )
+  }
+
+  /**
+   * Clears the interest ops of `key` and, if a connection is tracked for it, reports `e` to
+   * that connection's exception callback and closes it on the connect pool. Failures along
+   * the way are logged at debug level and otherwise ignored.
+   *
+   * MUST be called within selector loop - else deadlock.
+   */
+  private def triggerForceCloseByException(key: SelectionKey, e: Exception): Unit = {
+    // Best-effort execution: exceptions are logged and swallowed.
+    def logAndIgnore(action: => Unit): Unit = {
+      try {
+        action
+      } catch {
+        case ignored: Exception => logDebug("Ignoring exception", ignored)
+      }
+    }
+
+    logAndIgnore { key.interestOps(0) }
+
+    Option(connectionsByKey.getOrElse(key, null)).foreach { conn =>
+      // Pushing to connect threadpool
+      handleConnectExecutor.execute(new Runnable {
+        override def run(): Unit = {
+          logAndIgnore { conn.callOnExceptionCallback(e) }
+          logAndIgnore { conn.close() }
+        }
+      })
+    }
+  }
+
+
+  /**
+   * Body of the selector thread. Until that thread is interrupted, each pass
+   *  1. connects and starts tracking the SendingConnections queued in registerRequests,
+   *  2. applies the interest-op changes queued in keyInterestChangeRequests,
+   *  3. calls select() and dispatches each selected key to accept / connect / read / write.
+   * Keys found cancelled or failing are force-closed via triggerForceCloseByException. An
+   * exception escaping the loop is logged and ends the loop.
+   */
+  def run() {
+    try {
+      while(!selectorThread.isInterrupted) {
+        while (!registerRequests.isEmpty) {
+          val conn: SendingConnection = registerRequests.dequeue()
+          addListeners(conn)
+          conn.connect()
+          addConnection(conn)
+        }
+
+        while(!keyInterestChangeRequests.isEmpty) {
+          val (key, ops) = keyInterestChangeRequests.dequeue()
+
+          try {
+            if (key.isValid) {
+              val connection = connectionsByKey.getOrElse(key, null)
+              // Requests for keys that are no longer tracked are dropped silently.
+              if (connection != null) {
+                val lastOps = key.interestOps()
+                key.interestOps(ops)
+
+                // hot loop - prevent materialization of string if trace not enabled.
+                if (isTraceEnabled()) {
+                  def intToOpStr(op: Int): String = {
+                    val opStrs = ArrayBuffer[String]()
+                    if ((op & SelectionKey.OP_READ) != 0) opStrs += "READ"
+                    if ((op & SelectionKey.OP_WRITE) != 0) opStrs += "WRITE"
+                    if ((op & SelectionKey.OP_CONNECT) != 0) opStrs += "CONNECT"
+                    if ((op & SelectionKey.OP_ACCEPT) != 0) opStrs += "ACCEPT"
+                    if (opStrs.nonEmpty) opStrs.mkString(" | ") else " "
+                  }
+
+                  logTrace("Changed key for connection to [" +
+                    connection.getRemoteConnectionManagerId()  + "] changed from [" +
+                      intToOpStr(lastOps) + "] to [" + intToOpStr(ops) + "]")
+                }
+              }
+            } else {
+              logInfo("Key not valid ? " + key)
+              throw new CancelledKeyException()
+            }
+          } catch {
+            case e: CancelledKeyException => {
+              logInfo("key already cancelled ? " + key, e)
+              triggerForceCloseByException(key, e)
+            }
+            case e: Exception => {
+              logError("Exception processing key " + key, e)
+              triggerForceCloseByException(key, e)
+            }
+          }
+        }
+
+        val selectedKeysCount =
+          try {
+            selector.select()
+          } catch {
+            // Explicitly only dealing with CancelledKeyException here since other exceptions
+            // should be dealt with differently.
+            case e: CancelledKeyException => {
+              // Some keys within the selectors list are invalid/closed. clear them.
+              val allKeys = selector.keys().iterator()
+
+              while (allKeys.hasNext) {
+                val key = allKeys.next()
+                try {
+                  if (! key.isValid) {
+                    logInfo("Key not valid ? " + key)
+                    throw new CancelledKeyException()
+                  }
+                } catch {
+                  case e: CancelledKeyException => {
+                    logInfo("key already cancelled ? " + key, e)
+                    triggerForceCloseByException(key, e)
+                  }
+                  case e: Exception => {
+                    logError("Exception processing key " + key, e)
+                    triggerForceCloseByException(key, e)
+                  }
+                }
+              }
+            }
+            // Nothing is dispatched on this pass; the loop simply goes around again.
+            0
+          }
+
+        if (selectedKeysCount == 0) {
+          logDebug("Selector selected " + selectedKeysCount + " of " + selector.keys.size +
+            " keys")
+        }
+        if (selectorThread.isInterrupted) {
+          logInfo("Selector thread was interrupted!")
+          return
+        }
+
+        if (0 != selectedKeysCount) {
+          val selectedKeys = selector.selectedKeys().iterator()
+          while (selectedKeys.hasNext) {
+            val key = selectedKeys.next
+            // Take the key out of the selected set before handling it.
+            selectedKeys.remove()
+            try {
+              if (key.isValid) {
+                // Only the first matching readiness is acted on per pass.
+                if (key.isAcceptable) {
+                  acceptConnection(key)
+                } else if (key.isConnectable) {
+                  triggerConnect(key)
+                } else if (key.isReadable) {
+                  triggerRead(key)
+                } else if (key.isWritable) {
+                  triggerWrite(key)
+                }
+              } else {
+                logInfo("Key not valid ? " + key)
+                throw new CancelledKeyException()
+              }
+            } catch {
+              // weird, but we saw this happening - even though key.isValid was true,
+              // key.isAcceptable would throw CancelledKeyException.
+              case e: CancelledKeyException => {
+                logInfo("key already cancelled ? " + key, e)
+                triggerForceCloseByException(key, e)
+              }
+              case e: Exception => {
+                logError("Exception processing key " + key, e)
+                triggerForceCloseByException(key, e)
+              }
+            }
+          }
+        }
+      }
+    } catch {
+      case e: Exception => logError("Error in select loop", e)
+    }
+  }
+
+  /**
+   * Accepts every connection currently pending on the server channel behind `key`. Each
+   * accepted channel is wrapped in a ReceivingConnection that delivers to receiveMessage, gets
+   * this manager's listeners and is tracked by its selection key.
+   */
+  def acceptConnection(key: SelectionKey) {
+    // Note: this local shadows the serverChannel field.
+    val serverChannel = key.channel.asInstanceOf[ServerSocketChannel]
+
+    var newChannel = serverChannel.accept()
+
+    // accept them all in a tight loop. non blocking accept with no processing, should be fine
+    while (newChannel != null) {
+      try {
+        val newConnectionId = new ConnectionId(id, idCount.getAndIncrement.intValue)
+        val newConnection = new ReceivingConnection(newChannel, selector, newConnectionId)
+        newConnection.onReceive(receiveMessage)
+        addListeners(newConnection)
+        addConnection(newConnection)
+        logInfo("Accepted connection from [" + newConnection.remoteAddress + "]")
+      } catch {
+        // might happen in case of issues with registering with selector
+        case e: Exception => logError("Error in accept loop", e)
+      }
+
+      newChannel = serverChannel.accept()
+    }
+  }
+
+  /** Wires this manager's interest-change, exception and close handlers into `connection`. */
+  private def addListeners(connection: Connection): Unit = {
+    connection.onKeyInterestChange(changeConnectionKeyInterest _)
+    connection.onException(handleConnectionError _)
+    connection.onClose(removeConnection _)
+  }
+
+  /** Starts tracking `connection` under its selection key. */
+  def addConnection(connection: Connection): Unit = {
+    val entry = (connection.key, connection)
+    connectionsByKey += entry
+  }
+
+  /**
+   * Stops tracking `connection` and fails the messages that were waiting on its peer.
+   *
+   * For a SendingConnection the remote manager is dropped from connectionsById and from the
+   * SASL waiting list. For a ReceivingConnection the SendingConnection to the same remote
+   * manager, if any, is dropped and closed as well. In both cases every MessageStatus recorded
+   * for that remote manager is completed with markDone(None) and removed. The selector is
+   * always woken up afterwards.
+   */
+  def removeConnection(connection: Connection) {
+    // Completes (without an ack) and forgets every tracked message associated with `remoteId`.
+    def failPendingMessages(remoteId: ConnectionManagerId) {
+      messageStatuses.synchronized {
+        messageStatuses.values.filter(_.connectionManagerId == remoteId).foreach { status =>
+          logInfo("Notifying " + status)
+          status.markDone(None)
+        }
+
+        messageStatuses.retain((_, status) => status.connectionManagerId != remoteId)
+      }
+    }
+
+    connectionsByKey -= connection.key
+
+    try {
+      connection match {
+        case sendingConnection: SendingConnection =>
+          val sendingConnectionManagerId = sendingConnection.getRemoteConnectionManagerId()
+          logInfo("Removing SendingConnection to " + sendingConnectionManagerId)
+
+          connectionsById -= sendingConnectionManagerId
+          connectionsAwaitingSasl -= connection.connectionId
+
+          failPendingMessages(sendingConnectionManagerId)
+        case receivingConnection: ReceivingConnection =>
+          val remoteConnectionManagerId = receivingConnection.getRemoteConnectionManagerId()
+          logInfo("Removing ReceivingConnection to " + remoteConnectionManagerId)
+
+          connectionsById.get(remoteConnectionManagerId) match {
+            case None =>
+              logError(s"Corresponding SendingConnection to ${remoteConnectionManagerId} not found")
+            case Some(sendingConnection) =>
+              connectionsById -= remoteConnectionManagerId
+              sendingConnection.close()
+
+              val sendingConnectionManagerId = sendingConnection.getRemoteConnectionManagerId()
+
+              assert(sendingConnectionManagerId == remoteConnectionManagerId)
+
+              failPendingMessages(sendingConnectionManagerId)
+          }
+        case _ => logError("Unsupported type of connection.")
+      }
+    } finally {
+      // So that the selection keys can be removed.
+      wakeupSelector()
+    }
+  }
+
+  /**
+   * Exception listener installed on connections: logs the failure and removes the connection.
+   * The exception itself is not inspected.
+   */
+  def handleConnectionError(connection: Connection, e: Exception): Unit = {
+    logInfo(
+      "Handling connection error on connection to " + connection.getRemoteConnectionManagerId())
+    removeConnection(connection)
+  }
+
+  /**
+   * Queues a request to set the interest ops of `connection`'s key to `ops`; the request is
+   * applied by the selector loop, which is woken up so that it notices the request.
+   */
+  def changeConnectionKeyInterest(connection: Connection, ops: Int): Unit = {
+    val request = (connection.key, ops)
+    keyInterestChangeRequests += request
+    // so that the registration happens promptly
+    wakeupSelector()
+  }
+
+  /**
+   * Receive callback installed on accepted connections: hands `message` to handleMessage on
+   * the message-handler pool instead of processing it on the calling thread.
+   */
+  def receiveMessage(connection: Connection, message: Message) {
+    val connectionManagerId = ConnectionManagerId.fromSocketAddress(message.senderAddress)
+    logDebug("Received [" + message + "] from [" + connectionManagerId + "]")
+    val runnable = new Runnable() {
+      // Taken when the task is created, so both delays below are measured from receipt.
+      val creationTime = System.currentTimeMillis
+      def run() {
+        logDebug("Handler thread delay is " + (System.currentTimeMillis - creationTime) + " ms")
+        handleMessage(connectionManagerId, message, connection)
+        logDebug("Handling delay is " + (System.currentTimeMillis - creationTime) + " ms")
+      }
+    }
+    handleMessageExecutor.execute(runnable)
+  }
+
+  /**
+   * Client side of the SASL exchange: processes a security message received for a connection
+   * that is waiting on negotiation. When negotiation is complete the connection is taken off
+   * the waiting list and threads blocked on its authentication monitor are notified; otherwise
+   * the next response token is sent to the remote side.
+   *
+   * @throws IOException if evaluating or sending the response fails (the connection is closed
+   *                     first)
+   */
+  private def handleClientAuthentication(
+      waitingConn: SendingConnection,
+      securityMsg: SecurityMessage,
+      connectionId : ConnectionId) {
+    // Negotiation finished: stop tracking the connection and wake up the blocked senders.
+    def finishNegotiation(logPrefix: String) {
+      logDebug(logPrefix + waitingConn.connectionId)
+      connectionsAwaitingSasl -= waitingConn.connectionId
+      waitingConn.getAuthenticated().synchronized {
+        waitingConn.getAuthenticated().notifyAll()
+      }
+    }
+
+    if (waitingConn.isSaslComplete()) {
+      finishNegotiation("Client sasl completed for id: ")
+    } else {
+      try {
+        val replyToken: Array[Byte] =
+          waitingConn.sparkSaslClient.saslResponse(securityMsg.getToken)
+        if (waitingConn.isSaslComplete()) {
+          finishNegotiation("Client sasl completed after evaluate for id: ")
+        } else {
+          val securityMsgResp = SecurityMessage.fromResponse(replyToken,
+            securityMsg.getConnectionId.toString)
+          val message = securityMsgResp.toBufferMessage
+          if (message == null) throw new IOException("Error creating security message")
+          sendSecurityMessage(waitingConn.getRemoteConnectionManagerId(), message)
+        }
+      } catch  {
+        case e: Exception => {
+          logError("Error handling sasl client authentication", e)
+          waitingConn.close()
+          throw new IOException("Error evaluating sasl response: ", e)
+        }
+      }
+    }
+  }
+
+  /**
+   * Server side of the SASL exchange: feeds the received token to the connection's
+   * SparkSaslServer (created on first use) and sends back the reply token, if there is one.
+   * Any failure is logged and the connection is closed; nothing is thrown to the caller.
+   */
+  private def handleServerAuthentication(
+      connection: Connection,
+      securityMsg: SecurityMessage,
+      connectionId: ConnectionId) {
+    if (!connection.isSaslComplete()) {
+      logDebug("saslContext not established")
+      try {
+        // Create the server lazily, under the connection's monitor so it is created only once.
+        connection.synchronized {
+          if (connection.sparkSaslServer == null) {
+            logDebug("Creating sasl Server")
+            connection.sparkSaslServer = new SparkSaslServer(securityManager)
+          }
+        }
+        val replyToken: Array[Byte] = connection.sparkSaslServer.response(securityMsg.getToken)
+        if (connection.isSaslComplete()) {
+          logDebug("Server sasl completed: " + connection.connectionId)
+        } else {
+          logDebug("Server sasl not completed: " + connection.connectionId)
+        }
+        if (replyToken != null) {
+          val securityMsgResp = SecurityMessage.fromResponse(replyToken,
+            securityMsg.getConnectionId)
+          val message = securityMsgResp.toBufferMessage
+          if (message == null) throw new Exception("Error creating security Message")
+          sendSecurityMessage(connection.getRemoteConnectionManagerId(), message)
+        }
+      } catch {
+        case e: Exception => {
+          // Pass the throwable as well so the stack trace is not lost.
+          logError("Error in server auth negotiation: " + e, e)
+          // It would probably be better to send an error message telling other side auth failed
+          // but for now just close
+          connection.close()
+        }
+      }
+    } else {
+      logDebug("connection already established for this connection id: " + connection.connectionId)
+    }
+  }
+
+
+  /**
+   * Authentication filter applied to incoming buffer messages when authentication is enabled.
+   *
+   * Security negotiation messages are routed to the client handler when their connection id is
+   * waiting on SASL locally, and to the server handler otherwise. Ordinary messages arriving
+   * on a connection that has not completed SASL are dropped.
+   *
+   * @return true if the message was consumed here (negotiation message, or dropped), false if
+   *         the caller should go on and process it as a normal message
+   */
+  private def handleAuthentication(conn: Connection, bufferMessage: BufferMessage): Boolean = {
+    if (bufferMessage.isSecurityNeg) {
+      logDebug("This is security neg message")
+
+      // parse as SecurityMessage
+      val securityMsg = SecurityMessage.fromBufferMessage(bufferMessage)
+      val connectionId = ConnectionId.createConnectionIdFromString(securityMsg.getConnectionId)
+
+      connectionsAwaitingSasl.get(connectionId) match {
+        case Some(waitingConn) => {
+          // Client - this must be in response to us doing Send
+          logDebug("Client handleAuth for id: " +  waitingConn.connectionId)
+          handleClientAuthentication(waitingConn, securityMsg, connectionId)
+        }
+        case None => {
+          // Server - someone sent us something and we haven't authenticated yet
+          logDebug("Server handleAuth for id: " + connectionId)
+          handleServerAuthentication(conn, securityMsg, connectionId)
+        }
+      }
+      true
+    } else if (!conn.isSaslComplete()) {
+      // We could handle this better and tell the client we need to do authentication
+      // negotiation, but for now just ignore them.
+      logError("message sent that is not security negotiation message on connection " +
+               "not authenticated yet, ignoring it!!")
+      true
+    } else {
+      false
+    }
+  }
+
+  /**
+   * Processes one received message (invoked on the message-handler pool by receiveMessage).
+   *
+   *  - With authentication enabled, messages consumed by handleAuthentication stop here.
+   *  - A message carrying an ack id completes the matching MessageStatus, if still tracked.
+   *  - Any other message is passed to the receive callback and an ack is always sent back: the
+   *    callback's response if it gave one, an error-flagged message if processing threw, or
+   *    otherwise an empty buffer message created for the request id.
+   *
+   * @throws Exception if `message` is not a BufferMessage
+   */
+  private def handleMessage(
+      connectionManagerId: ConnectionManagerId,
+      message: Message,
+      connection: Connection) {
+    logDebug("Handling [" + message + "] from [" + connectionManagerId + "]")
+    message match {
+      case bufferMessage: BufferMessage => {
+        if (authEnabled) {
+          val res = handleAuthentication(connection, bufferMessage)
+          if (res) {
+            // message was security negotiation so skip the rest
+            logDebug("After handleAuth result was true, returning")
+            return
+          }
+        }
+        if (bufferMessage.hasAckId()) {
+          messageStatuses.synchronized {
+            messageStatuses.get(bufferMessage.ackId) match {
+              case Some(status) => {
+                messageStatuses -= bufferMessage.ackId
+                status.markDone(Some(message))
+              }
+              case None => {
+                /**
+                 * We can fall down on this code because of following 2 cases
+                 *
+                 * (1) Invalid ack sent due to buggy code.
+                 *
+                 * (2) Late-arriving ack for a SendMessageStatus
+                 *     To avoid unwilling late-arriving ack
+                 *     caused by long pause like GC, you can set
+                 *     larger value than default to spark.core.connection.ack.wait.timeout
+                 */
+                logWarning(s"Could not find reference for received ack Message ${message.id}")
+              }
+            }
+          }
+        } else {
+          var ackMessage : Option[Message] = None
+          try {
+            ackMessage = if (onReceiveCallback != null) {
+              logDebug("Calling back")
+              onReceiveCallback(bufferMessage, connectionManagerId)
+            } else {
+              logDebug("Not calling back as callback is null")
+              None
+            }
+
+            // Make sure a buffer-message response acknowledges the request it answers.
+            ackMessage.foreach {
+              case ack: BufferMessage =>
+                if (!ack.hasAckId) {
+                  logDebug("Response to " + bufferMessage + " does not have ack id set")
+                  ack.ackId = bufferMessage.id
+                }
+              case other =>
+                logDebug("Response to " + bufferMessage + " is not a buffer message, it is of type "
+                  + other.getClass)
+            }
+          } catch {
+            case e: Exception => {
+              logError("Exception was thrown while processing message", e)
+              val m = Message.createBufferMessage(bufferMessage.id)
+              m.hasError = true
+              ackMessage = Some(m)
+            }
+          } finally {
+            // Always reply, so the sender is not left waiting for an ack.
+            sendMessage(connectionManagerId, ackMessage.getOrElse {
+              Message.createBufferMessage(bufferMessage.id)
+            })
+          }
+        }
+      }
+      case _ => throw new Exception("Unknown type message received")
+    }
+  }
+
+  /**
+   * Starts the client side of SASL negotiation on `conn` unless it is already complete or
+   * already started: creates the SparkSaslClient, records the connection as awaiting SASL and
+   * sends the first token to `connManagerId`.
+   *
+   * @throws Exception if the first token cannot be produced or sent; the connection is closed
+   *                   first and the original failure is attached as the cause
+   */
+  private def checkSendAuthFirst(connManagerId: ConnectionManagerId, conn: SendingConnection) {
+    // see if we need to do sasl before writing
+    // this should only be the first negotiation as the Client!!!
+    if (!conn.isSaslComplete()) {
+      conn.synchronized {
+        // A non-null client means negotiation was already started for this connection.
+        if (conn.sparkSaslClient == null) {
+          conn.sparkSaslClient = new SparkSaslClient(securityManager)
+          try {
+            val firstResponse: Array[Byte] = conn.sparkSaslClient.firstToken()
+            val securityMsg = SecurityMessage.fromResponse(firstResponse,
+              conn.connectionId.toString())
+            val message = securityMsg.toBufferMessage
+            if (message == null) throw new Exception("Error creating security message")
+            connectionsAwaitingSasl += ((conn.connectionId, conn))
+            sendSecurityMessage(connManagerId, message)
+            logDebug("adding connectionsAwaitingSasl id: " + conn.connectionId)
+          } catch {
+            case e: Exception => {
+              logError("Error getting first response from the SaslClient.", e)
+              conn.close()
+              // Keep the original failure as the cause instead of dropping it.
+              throw new Exception("Error getting first response from the SaslClient", e)
+            }
+          }
+        }
+      }
+    } else {
+      logDebug("Sasl already established ")
+    }
+  }
+
+  /**
+   * Sends a SASL negotiation message to `connManagerId`, reusing the SendingConnection tracked
+   * for that manager or queueing a new one for registration by the selector thread. Unlike
+   * sendMessage, this neither starts nor waits for authentication.
+   */
+  private def sendSecurityMessage(connManagerId: ConnectionManagerId, message: Message) {
+    // Creates a connection and queues it; the selector loop performs the actual connect.
+    def startNewConnection(): SendingConnection = {
+      val inetSocketAddress = new InetSocketAddress(connManagerId.host, connManagerId.port)
+      val newConnectionId = new ConnectionId(id, idCount.getAndIncrement.intValue)
+      val newConnection = new SendingConnection(inetSocketAddress, selector, connManagerId,
+        newConnectionId)
+      logInfo("creating new sending connection for security! " + newConnectionId )
+      registerRequests.enqueue(newConnection)
+
+      newConnection
+    }
+    // I removed the lookupKey stuff as part of merge ... should I re-add it ?
+    // We did not find it useful in our test-env ...
+    // If we do re-add it, we should consistently use it everywhere I guess ?
+    message.senderAddress = id.toSocketAddress()
+    logTrace("Sending Security [" + message + "] to [" + connManagerId + "]")
+    val connection = connectionsById.getOrElseUpdate(connManagerId, startNewConnection())
+
+    // security messages are sent right away, before the connection has been authenticated
+    connection.send(message)
+
+    wakeupSelector()
+  }
+
+  /**
+   * Sends `message` to `connectionManagerId`, reusing the SendingConnection tracked for that
+   * manager or queueing a new one for registration by the selector thread.
+   *
+   * With authentication enabled, negotiation is started if needed and the calling thread then
+   * blocks until the connection reports SASL complete or `authTimeout` seconds have passed. On
+   * a failed wait the MessageStatus tracked for the message, if any, is completed without an
+   * ack.
+   */
+  private def sendMessage(connectionManagerId: ConnectionManagerId, message: Message) {
+    // Creates a connection and queues it; the selector loop performs the actual connect.
+    def startNewConnection(): SendingConnection = {
+      val inetSocketAddress = new InetSocketAddress(connectionManagerId.host,
+        connectionManagerId.port)
+      val newConnectionId = new ConnectionId(id, idCount.getAndIncrement.intValue)
+      val newConnection = new SendingConnection(inetSocketAddress, selector, connectionManagerId,
+        newConnectionId)
+      logTrace("creating new sending connection: " + newConnectionId)
+      registerRequests.enqueue(newConnection)
+
+      newConnection
+    }
+    val connection = connectionsById.getOrElseUpdate(connectionManagerId, startNewConnection())
+    if (authEnabled) {
+      checkSendAuthFirst(connectionManagerId, connection)
+    }
+    message.senderAddress = id.toSocketAddress()
+    logDebug("Before Sending [" + message + "] to [" + connectionManagerId + "]" + " " +
+      "connectionid: "  + connection.connectionId)
+
+    if (authEnabled) {
+      // if we aren't authenticated yet lets block the senders until authentication completes
+      try {
+        connection.getAuthenticated().synchronized {
+          val clock = SystemClock
+          val startTime = clock.getTime()
+
+          while (!connection.isSaslComplete()) {
+            logDebug("getAuthenticated wait connectionid: " + connection.connectionId)
+            // have timeout in case remote side never responds
+            connection.getAuthenticated().wait(500)
+            if (((clock.getTime() - startTime) >= (authTimeout * 1000))
+              && (!connection.isSaslComplete())) {
+              // took too long to authenticate the connection, something probably went wrong
+              throw new Exception("Took too long for authentication to " + connectionManagerId +
+                ", waited " + authTimeout + " seconds, failing.")
+            }
+          }
+        }
+      } catch {
+        case e: Exception =>
+          logError("Exception while waiting for authentication.", e)
+
+          // need to tell sender it failed
+          messageStatuses.synchronized {
+            messageStatuses.get(message.id) match {
+              case Some(msgStatus) => {
+                messageStatuses -= message.id
+                logInfo("Notifying " + msgStatus.connectionManagerId)
+                msgStatus.markDone(None)
+              }
+              case None => {
+                logError("no messageStatus for failed message id: " + message.id)
+              }
+            }
+          }
+      }
+    }
+    // NOTE(review): this point is also reached after a failed authentication wait, so the
+    // message is still handed to the connection in that case -- confirm this is intended.
+    logDebug("Sending [" + message + "] to [" + connectionManagerId + "]")
+    connection.send(message)
+
+    wakeupSelector()
+  }
+
+  /** Calls `wakeup()` on this manager's selector. */
+  private def wakeupSelector(): Unit = {
+    selector.wakeup()
+  }
+
+  /**
+   * Send a message and return a Future that completes when an acknowledgment is received,
+   * the send is reported as failed, or the ACK timeout task fires first.
+   * @param connectionManagerId the message's destination
+   * @param message the message being sent
+   * @return a Future that either returns the acknowledgment message or captures an exception.
+   */
+  def sendMessageReliably(connectionManagerId: ConnectionManagerId, message: Message)
+      : Future[Message] = {
+    val ackPromise = Promise[Message]()
+
+    // Fails the promise only if the status is still registered when the timer fires.
+    val ackTimeoutTask = new TimerTask {
+      override def run(): Unit = messageStatuses.synchronized {
+        if (messageStatuses.remove(message.id).isDefined) {
+          ackPromise.failure(
+            new IOException("sendMessageReliably failed because ack " +
+              s"was not received within $ackTimeout sec"))
+        }
+      }
+    }
+
+    // Completion callback handed to the status: cancel the timer, then settle the promise.
+    val status = new MessageStatus(message, connectionManagerId, finished => {
+      ackTimeoutTask.cancel()
+      finished.ackMessage match {
+        case Some(ack) if ack.hasError =>
+          ackPromise.failure(
+            new IOException("sendMessageReliably failed with ACK that signalled a remote error"))
+        case Some(ack) =>
+          ackPromise.success(ack)
+        case None =>
+          // Indicates a failure where we either never sent or never got ACK'd
+          ackPromise.failure(new IOException("sendMessageReliably failed without being ACK'd"))
+      }
+    })
+    messageStatuses.synchronized {
+      messageStatuses += (message.id -> status)
+    }
+
+    ackTimeoutMonitor.schedule(ackTimeoutTask, ackTimeout * 1000)
+    sendMessage(connectionManagerId, message)
+    ackPromise.future
+  }
+
+  /** Stores `callback` in `onReceiveCallback`, replacing whatever was assigned before. */
+  def onReceiveMessage(callback: (Message, ConnectionManagerId) => Option[Message]): Unit = {
+    onReceiveCallback = callback
+  }
+
+  /**
+   * Shuts the manager down: cancels the ACK timeout monitor, interrupts and joins the selector
+   * thread, closes the selector and every connection in `connectionsByKey`, then shuts down
+   * the three handler executors.
+   */
+  def stop(): Unit = {
+    ackTimeoutMonitor.cancel()
+
+    // Stop the selector loop before tearing down what it operates on.
+    selectorThread.interrupt()
+    selectorThread.join()
+    selector.close()
+
+    connectionsByKey.values.foreach { connection => connection.close() }
+    if (connectionsByKey.size != 0) {
+      logWarning("All connections not cleaned up")
+    }
+
+    handleMessageExecutor.shutdown()
+    handleReadWriteExecutor.shutdown()
+    handleConnectExecutor.shutdown()
+    logInfo("ConnectionManager stopped")
+  }
+}
+
+
+/**
+ * Manual throughput harness: `main` starts a ConnectionManager on port 9999 and runs one of
+ * the test routines below, each of which sends buffers from the manager to its own id.
+ */
+private[spark] object ConnectionManager {
+  import scala.concurrent.ExecutionContext.Implicits.global
+
+  def main(args: Array[String]) {
+    val conf = new SparkConf
+    val manager = new ConnectionManager(9999, conf, new SecurityManager(conf))
+    manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
+      println("Received [" + msg + "] from [" + id + "]")
+      None
+    })
+
+    /* testSequentialSending(manager) */
+    /* System.gc() */
+
+    /* testParallelSending(manager) */
+    /* System.gc() */
+
+    /* testParallelDecreasingSending(manager) */
+    /* System.gc() */
+
+    testContinuousSending(manager)
+    System.gc()
+  }
+
+  /** Prints the three-line banner that introduces a test run. */
+  private def printHeader(title: String) {
+    println("--------------------------")
+    println(title)
+    println("--------------------------")
+  }
+
+  /** Allocates `size` bytes, fills byte i with `i.toByte`, and flips the buffer for reading. */
+  private def filledBuffer(size: Int): ByteBuffer = {
+    val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte))
+    buffer.flip
+    buffer
+  }
+
+  /**
+   * Issues one reliable send per payload (each on a duplicate of the buffer), and only after
+   * all sends have been issued waits up to one second on each future, printing failures.
+   * `payloads` must be a strict collection so that every send starts before the first wait.
+   */
+  private def sendAllAndAwait(manager: ConnectionManager, payloads: Seq[ByteBuffer]) {
+    payloads.map { payload =>
+      val bufferMessage = Message.createBufferMessage(payload.duplicate)
+      manager.sendMessageReliably(manager.id, bufferMessage)
+    }.foreach { f =>
+      f.onFailure {
+        case e => println("Failed due to " + e)
+      }
+      Await.ready(f, 1 second)
+    }
+  }
+
+  /** Sends `count` 10 MB messages one at a time, waiting for each ACK before the next send. */
+  def testSequentialSending(manager: ConnectionManager) {
+    printHeader("Sequential Sending")
+    val size = 10 * 1024 * 1024
+    val count = 10
+
+    val buffer = filledBuffer(size)
+
+    (0 until count).foreach { _ =>
+      val bufferMessage = Message.createBufferMessage(buffer.duplicate)
+      Await.result(manager.sendMessageReliably(manager.id, bufferMessage), Duration.Inf)
+    }
+    println("--------------------------")
+    println()
+  }
+
+  /** Issues `count` 10 MB sends back to back, then waits for them and prints the throughput. */
+  def testParallelSending(manager: ConnectionManager) {
+    printHeader("Parallel Sending")
+    val size = 10 * 1024 * 1024
+    val count = 10
+
+    val buffer = filledBuffer(size)
+
+    val startTime = System.currentTimeMillis
+    sendAllAndAwait(manager, List.fill(count)(buffer))
+    val finishTime = System.currentTimeMillis
+
+    val mb = size * count / 1024.0 / 1024.0
+    val ms = finishTime - startTime
+    val tput = mb * 1000.0 / ms
+    println("--------------------------")
+    println("Started at " + startTime + ", finished at " + finishTime)
+    println("Sent " + count + " messages of size " + size + " in " + ms + " ms " +
+      "(" + tput + " MB/s)")
+    println("--------------------------")
+    println()
+  }
+
+  /** Like the parallel test, but with buffers of 10..100 MB sent largest first. */
+  def testParallelDecreasingSending(manager: ConnectionManager) {
+    printHeader("Parallel Decreasing Sending")
+    val size = 10 * 1024 * 1024
+    val count = 10
+    val buffers = Array.tabulate(count) { i => filledBuffer(size * (i + 1)) }
+    val mb = buffers.map(_.remaining).reduceLeft(_ + _) / 1024.0 / 1024.0
+
+    val startTime = System.currentTimeMillis
+    sendAllAndAwait(manager, buffers.reverse.toList)
+    val finishTime = System.currentTimeMillis
+
+    val ms = finishTime - startTime
+    val tput = mb * 1000.0 / ms
+    println("--------------------------")
+    /* println("Started at " + startTime + ", finished at " + finishTime) */
+    println("Sent " + mb + " MB in " + ms + " ms (" + tput + " MB/s)")
+    println("--------------------------")
+    println()
+  }
+
+  /** Repeats the parallel batch forever, printing the throughput of each batch. */
+  def testContinuousSending(manager: ConnectionManager) {
+    printHeader("Continuous Sending")
+    val size = 10 * 1024 * 1024
+    val count = 10
+
+    val buffer = filledBuffer(size)
+
+    while (true) {
+      // Time each batch on its own. Measuring from a single start taken before the loop
+      // would divide one batch's volume by the cumulative runtime, sleeps included.
+      val startTime = System.currentTimeMillis
+      sendAllAndAwait(manager, List.fill(count)(buffer))
+      val finishTime = System.currentTimeMillis
+      Thread.sleep(1000)
+      val mb = size * count / 1024.0 / 1024.0
+      val ms = finishTime - startTime
+      val tput = mb * 1000.0 / ms
+      println("Sent " + mb + " MB in " + ms + " ms (" + tput + " MB/s)")
+      println("--------------------------")
+      println()
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/08ce1888/core/src/main/scala/org/apache/spark/network/nio/ConnectionManagerId.scala
----------------------------------------------------------------------
diff --git 
a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManagerId.scala 
b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManagerId.scala
new file mode 100644
index 0000000..cbb37ec
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManagerId.scala
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.nio
+
+import java.net.InetSocketAddress
+
+import org.apache.spark.util.Utils
+
+/**
+ * Host/port pair identifying a connection manager. Construction passes the host to
+ * `Utils.checkHost` and asserts that the port is positive.
+ */
+private[nio] case class ConnectionManagerId(host: String, port: Int) {
+  // DEBUG code
+  Utils.checkHost(host)
+  assert(port > 0)
+
+  /** Builds a new InetSocketAddress from this id's host and port. */
+  def toSocketAddress(): InetSocketAddress = {
+    new InetSocketAddress(host, port)
+  }
+}
+
+
+private[nio] object ConnectionManagerId {
+  /** Creates an id from the address's `getHostName` and `getPort` values. */
+  def fromSocketAddress(socketAddress: InetSocketAddress): ConnectionManagerId = {
+    val host = socketAddress.getHostName
+    val port = socketAddress.getPort
+    new ConnectionManagerId(host, port)
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/08ce1888/core/src/main/scala/org/apache/spark/network/nio/Message.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/network/nio/Message.scala 
b/core/src/main/scala/org/apache/spark/network/nio/Message.scala
new file mode 100644
index 0000000..0b874c2
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/network/nio/Message.scala
@@ -0,0 +1,96 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.nio
+
+import java.net.InetSocketAddress
+import java.nio.ByteBuffer
+
+import scala.collection.mutable.ArrayBuffer
+
+
+/**
+ * Base class for messages in this package: a type tag and an id, plus mutable bookkeeping
+ * fields (sender address, timing, security-negotiation and error flags) and the chunking
+ * hooks subclasses must implement.
+ */
+private[nio] abstract class Message(val typ: Long, val id: Int) {
+  var senderAddress: InetSocketAddress = null
+  var started: Boolean = false
+  var startTime: Long = -1L
+  var finishTime: Long = -1L
+  var isSecurityNeg: Boolean = false
+  var hasError: Boolean = false
+
+  def size: Int
+
+  def getChunkForSending(maxChunkSize: Int): Option[MessageChunk]
+
+  def getChunkForReceiving(chunkSize: Int): Option[MessageChunk]
+
+  /** Difference between `finishTime` and `startTime`, rendered with an " ms" suffix. */
+  def timeTaken(): String = s"${finishTime - startTime} ms"
+
+  override def toString: String = s"${this.getClass.getSimpleName}(id = $id, size = $size)"
+}
+
+
+/** Factory methods for messages, plus the id counter they share. */
+private[nio] object Message {
+  val BUFFER_MESSAGE = 1111111111L
+
+  var lastId = 1
+
+  /** Returns the next message id; the value 0 is skipped when the counter reaches it. */
+  def getNewId(): Int = synchronized {
+    lastId += 1
+    if (lastId == 0) {
+      lastId += 1
+    }
+    lastId
+  }
+
+  /**
+   * Wraps `dataBuffers` in a new BufferMessage. A null sequence yields a message with no
+   * buffers; a null element inside the sequence is rejected with an Exception.
+   */
+  def createBufferMessage(dataBuffers: Seq[ByteBuffer], ackId: Int): BufferMessage = {
+    if (dataBuffers == null) {
+      new BufferMessage(getNewId(), new ArrayBuffer[ByteBuffer], ackId)
+    } else if (dataBuffers.exists(_ == null)) {
+      throw new Exception("Attempting to create buffer message with null buffer")
+    } else {
+      new BufferMessage(getNewId(), new ArrayBuffer[ByteBuffer] ++= dataBuffers, ackId)
+    }
+  }
+
+  def createBufferMessage(dataBuffers: Seq[ByteBuffer]): BufferMessage =
+    createBufferMessage(dataBuffers, 0)
+
+  /** Single-buffer variant; a null buffer is replaced by an empty zero-length buffer. */
+  def createBufferMessage(dataBuffer: ByteBuffer, ackId: Int): BufferMessage = {
+    val payload = if (dataBuffer == null) ByteBuffer.allocate(0) else dataBuffer
+    createBufferMessage(Array(payload), ackId)
+  }
+
+  def createBufferMessage(dataBuffer: ByteBuffer): BufferMessage =
+    createBufferMessage(dataBuffer, 0)
+
+  /** Creates a message that carries no data buffers, only the given `ackId`. */
+  def createBufferMessage(ackId: Int): BufferMessage = {
+    createBufferMessage(new Array[ByteBuffer](0), ackId)
+  }
+
+  /**
+   * Creates an empty message sized from a received chunk header and copies the header's
+   * error flag and address onto it. Only BUFFER_MESSAGE headers are matched.
+   */
+  def create(header: MessageChunkHeader): Message = {
+    val created: Message = header.typ match {
+      case BUFFER_MESSAGE =>
+        val storage = ArrayBuffer(ByteBuffer.allocate(header.totalSize))
+        new BufferMessage(header.id, storage, header.other)
+    }
+    created.hasError = header.hasError
+    created.senderAddress = header.address
+    created
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/08ce1888/core/src/main/scala/org/apache/spark/network/nio/MessageChunk.scala
----------------------------------------------------------------------
diff --git 
a/core/src/main/scala/org/apache/spark/network/nio/MessageChunk.scala 
b/core/src/main/scala/org/apache/spark/network/nio/MessageChunk.scala
new file mode 100644
index 0000000..278c5ac
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/network/nio/MessageChunk.scala
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.nio
+
+import java.nio.ByteBuffer
+
+import scala.collection.mutable.ArrayBuffer
+
+/** A chunk header paired with an optional (possibly null) data buffer. */
+private[nio]
+class MessageChunk(val header: MessageChunkHeader, val buffer: ByteBuffer) {
+
+  // Bytes remaining in the data buffer at construction time; 0 when there is no buffer.
+  val size = buffer match {
+    case null => 0
+    case data => data.remaining
+  }
+
+  // The header's buffer, followed by the data buffer when one is present.
+  lazy val buffers = {
+    val parts = ArrayBuffer[ByteBuffer](header.buffer)
+    if (buffer != null) {
+      parts += buffer
+    }
+    parts
+  }
+
+  override def toString = s"${this.getClass.getSimpleName} (id = ${header.id}, size = $size)"
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/08ce1888/core/src/main/scala/org/apache/spark/network/nio/MessageChunkHeader.scala
----------------------------------------------------------------------
diff --git 
a/core/src/main/scala/org/apache/spark/network/nio/MessageChunkHeader.scala 
b/core/src/main/scala/org/apache/spark/network/nio/MessageChunkHeader.scala
new file mode 100644
index 0000000..6e20f29
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/network/nio/MessageChunkHeader.scala
@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.nio
+
+import java.net.{InetAddress, InetSocketAddress}
+import java.nio.ByteBuffer
+
+/**
+ * Header of a message chunk. `buffer` lazily serializes the fields, in declaration order
+ * followed by the raw address length, address bytes and port, into a buffer of
+ * `MessageChunkHeader.HEADER_SIZE` bytes.
+ */
+private[nio] class MessageChunkHeader(
+    val typ: Long,
+    val id: Int,
+    val totalSize: Int,
+    val chunkSize: Int,
+    val other: Int,
+    val hasError: Boolean,
+    val securityNeg: Int,
+    val address: InetSocketAddress) {
+  lazy val buffer = {
+    // No need to change this, at 'use' time, we do a reverse lookup of the hostname.
+    // Refer to network.Connection
+    val ipBytes = address.getAddress.getAddress()
+    val portNumber = address.getPort()
+    val errorFlag: Byte = if (hasError) 1 else 0
+
+    // NOTE(review): the fixed fields take 37 bytes plus the raw address, so a 16-byte
+    // address would not appear to fit in HEADER_SIZE (45) - confirm IPv4 is assumed.
+    val out = ByteBuffer.allocate(MessageChunkHeader.HEADER_SIZE)
+    out.putLong(typ)
+    out.putInt(id)
+    out.putInt(totalSize)
+    out.putInt(chunkSize)
+    out.putInt(other)
+    out.put(errorFlag)
+    out.putInt(securityNeg)
+    out.putInt(ipBytes.length)
+    out.put(ipBytes)
+    out.putInt(portNumber)
+    // Pad to the fixed header size before flipping, whatever the address length was.
+    out.position(MessageChunkHeader.HEADER_SIZE)
+    out.flip()
+    out
+  }
+
+  override def toString = s"${this.getClass.getSimpleName}:$id of type $typ" +
+      s" and sizes $totalSize / $chunkSize bytes, securityNeg: $securityNeg"
+
+}
+
+
+private[nio] object MessageChunkHeader {
+  val HEADER_SIZE = 45
+
+  /**
+   * Parses a header from `buffer`, reading the fields in the order the class's `buffer`
+   * writes them.
+   * @throws IllegalArgumentException if the buffer does not hold exactly HEADER_SIZE bytes
+   */
+  def create(buffer: ByteBuffer): MessageChunkHeader = {
+    if (buffer.remaining != HEADER_SIZE) {
+      throw new IllegalArgumentException("Cannot convert buffer data to Message")
+    }
+    val typ = buffer.getLong()
+    val id = buffer.getInt()
+    val totalSize = buffer.getInt()
+    val chunkSize = buffer.getInt()
+    val other = buffer.getInt()
+    val hasError = buffer.get() != 0
+    val securityNeg = buffer.getInt()
+
+    // The address is stored as a length-prefixed raw byte array, followed by the port.
+    val rawAddress = new Array[Byte](buffer.getInt())
+    buffer.get(rawAddress)
+    val inetAddress = InetAddress.getByAddress(rawAddress)
+    val socketAddress = new InetSocketAddress(inetAddress, buffer.getInt())
+
+    new MessageChunkHeader(typ, id, totalSize, chunkSize, other, hasError, securityNeg,
+      socketAddress)
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/08ce1888/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala
----------------------------------------------------------------------
diff --git 
a/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala
 
b/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala
new file mode 100644
index 0000000..59958ee
--- /dev/null
+++ 
b/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala
@@ -0,0 +1,205 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.nio
+
+import java.nio.ByteBuffer
+
+import scala.concurrent.Future
+
+import org.apache.spark.{SparkException, Logging, SecurityManager, SparkConf}
+import org.apache.spark.network._
+import org.apache.spark.storage.{BlockId, StorageLevel}
+import org.apache.spark.util.Utils
+
+
+/**
+ * A [[BlockTransferService]] implementation based on [[ConnectionManager]], a custom
+ * implementation using Java NIO.
+ */
+final class NioBlockTransferService(conf: SparkConf, securityManager: SecurityManager)
+  extends BlockTransferService with Logging {
+
+  private var cm: ConnectionManager = _
+
+  private var blockDataManager: BlockDataManager = _
+
+  /**
+   * Port number the service is listening on, available only after [[init]] is invoked.
+   */
+  override def port: Int = {
+    checkInit()
+    cm.id.port
+  }
+
+  /**
+   * Host name the service is listening on, available only after [[init]] is invoked.
+   */
+  override def hostName: String = {
+    checkInit()
+    cm.id.host
+  }
+
+  /**
+   * Initialize the transfer service by giving it the BlockDataManager that can be used to fetch
+   * local blocks or put local blocks.
+   */
+  override def init(blockDataManager: BlockDataManager): Unit = {
+    this.blockDataManager = blockDataManager
+    cm = new ConnectionManager(
+      conf.getInt("spark.blockManager.port", 0),
+      conf,
+      securityManager,
+      "Connection manager for block manager")
+    cm.onReceiveMessage(onBlockMessageReceive)
+  }
+
+  /**
+   * Tear down the transfer service.
+   */
+  override def stop(): Unit = {
+    if (cm != null) {
+      cm.stop()
+    }
+  }
+
+  /**
+   * Requests `blockIds` from the given remote host/port in a single message and reports every
+   * block message in the reply to `listener`; a failed request is reported once through
+   * `onBlockFetchFailure`.
+   */
+  override def fetchBlocks(
+      hostName: String,
+      port: Int,
+      blockIds: Seq[String],
+      listener: BlockFetchingListener): Unit = {
+    checkInit()
+
+    val cmId = new ConnectionManagerId(hostName, port)
+    val blockMessageArray = new BlockMessageArray(blockIds.map { blockId =>
+      BlockMessage.fromGetBlock(GetBlock(BlockId(blockId)))
+    })
+
+    val future = cm.sendMessageReliably(cmId, blockMessageArray.toBufferMessage)
+
+    // Register the listener on success/failure future callback.
+    future.onSuccess { case message =>
+      val bufferMessage = message.asInstanceOf[BufferMessage]
+      val blockMessageArray = BlockMessageArray.fromBufferMessage(bufferMessage)
+
+      for (blockMessage <- blockMessageArray) {
+        if (blockMessage.getType != BlockMessage.TYPE_GOT_BLOCK) {
+          listener.onBlockFetchFailure(
+            new SparkException(s"Unexpected message ${blockMessage.getType} received from $cmId"))
+        } else {
+          val blockId = blockMessage.getId
+          listener.onBlockFetchSuccess(
+            blockId.toString, new NioByteBufferManagedBuffer(blockMessage.getData))
+        }
+      }
+    }(cm.futureExecContext)
+
+    future.onFailure { case exception =>
+      listener.onBlockFetchFailure(exception)
+    }(cm.futureExecContext)
+  }
+
+  /**
+   * Upload a single block to a remote node, available only after [[init]] is invoked.
+   *
+   * The call does not block: the returned Future completes when the reliable send to the
+   * remote node completes, and fails if that send fails.
+   */
+  override def uploadBlock(
+      hostname: String,
+      port: Int,
+      blockId: String,
+      blockData: ManagedBuffer,
+      level: StorageLevel)
+    : Future[Unit] = {
+    checkInit()
+    val msg = PutBlock(BlockId(blockId), blockData.nioByteBuffer(), level)
+    val blockMessageArray = new BlockMessageArray(BlockMessage.fromPutBlock(msg))
+    // Use the `hostname` parameter here. `hostName` (capital N) is this service's own host
+    // accessor, which would address the upload to the local node.
+    val remoteCmId = new ConnectionManagerId(hostname, port)
+    val reply = cm.sendMessageReliably(remoteCmId, blockMessageArray.toBufferMessage)
+    reply.map(x => ())(cm.futureExecContext)
+  }
+
+  private def checkInit(): Unit = if (cm == null) {
+    throw new IllegalStateException(getClass.getName + " has not been initialized")
+  }
+
+  /**
+   * Receive callback registered with the connection manager: parses the message as a block
+   * message array, processes each entry and returns the collected responses. Any exception,
+   * or a message that is not a BufferMessage, produces a reply flagged with `hasError`.
+   */
+  private def onBlockMessageReceive(msg: Message, id: ConnectionManagerId): Option[Message] = {
+    logDebug("Handling message " + msg)
+    msg match {
+      case bufferMessage: BufferMessage =>
+        try {
+          logDebug("Handling as a buffer message " + bufferMessage)
+          val blockMessages = BlockMessageArray.fromBufferMessage(bufferMessage)
+          logDebug("Parsed as a block message array")
+          val responseMessages =
+            blockMessages.map(processBlockMessage).filter(_ != None).map(_.get)
+          Some(new BlockMessageArray(responseMessages).toBufferMessage)
+        } catch {
+          case e: Exception => {
+            logError("Exception handling buffer message", e)
+            val errorMessage = Message.createBufferMessage(msg.id)
+            errorMessage.hasError = true
+            Some(errorMessage)
+          }
+        }
+
+      case otherMessage: Any =>
+        logError("Unknown type message received: " + otherMessage)
+        val errorMessage = Message.createBufferMessage(msg.id)
+        errorMessage.hasError = true
+        Some(errorMessage)
+    }
+  }
+
+  /**
+   * Handles one block message: a PUT stores the block and produces no response, a GET
+   * produces a GOT_BLOCK response when the block is available, anything else is ignored.
+   */
+  private def processBlockMessage(blockMessage: BlockMessage): Option[BlockMessage] = {
+    blockMessage.getType match {
+      case BlockMessage.TYPE_PUT_BLOCK =>
+        val msg = PutBlock(blockMessage.getId, blockMessage.getData, blockMessage.getLevel)
+        logDebug("Received [" + msg + "]")
+        putBlock(msg.id.toString, msg.data, msg.level)
+        None
+
+      case BlockMessage.TYPE_GET_BLOCK =>
+        val msg = new GetBlock(blockMessage.getId)
+        logDebug("Received [" + msg + "]")
+        getBlock(msg.id.toString).map { buffer =>
+          BlockMessage.fromGotBlock(GotBlock(msg.id, buffer))
+        }
+
+      case _ => None
+    }
+  }
+
+  private def putBlock(blockId: String, bytes: ByteBuffer, level: StorageLevel) {
+    val startTimeMs = System.currentTimeMillis()
+    logDebug("PutBlock " + blockId + " started from " + startTimeMs + " with data: " + bytes)
+    blockDataManager.putBlockData(blockId, new NioByteBufferManagedBuffer(bytes), level)
+    logDebug("PutBlock " + blockId + " used " + Utils.getUsedTimeMs(startTimeMs)
+      + " with data size: " + bytes.limit)
+  }
+
+  /**
+   * Looks the block up in the BlockDataManager and returns its bytes, or None when the
+   * manager has no data for `blockId` (instead of dereferencing a null buffer).
+   */
+  private def getBlock(blockId: String): Option[ByteBuffer] = {
+    val startTimeMs = System.currentTimeMillis()
+    logDebug("GetBlock " + blockId + " started from " + startTimeMs)
+    val buffer = blockDataManager.getBlockData(blockId)
+    logDebug("GetBlock " + blockId + " used " + Utils.getUsedTimeMs(startTimeMs)
+      + " and got buffer " + buffer.orNull)
+    buffer.map(_.nioByteBuffer())
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/08ce1888/core/src/main/scala/org/apache/spark/network/nio/SecurityMessage.scala
----------------------------------------------------------------------
diff --git 
a/core/src/main/scala/org/apache/spark/network/nio/SecurityMessage.scala 
b/core/src/main/scala/org/apache/spark/network/nio/SecurityMessage.scala
new file mode 100644
index 0000000..747a208
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/network/nio/SecurityMessage.scala
@@ -0,0 +1,160 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.nio
+
+import java.nio.ByteBuffer
+
+import scala.collection.mutable.{ArrayBuffer, StringBuilder}
+
+import org.apache.spark._
+
+/**
+ * SecurityMessage is class that contains the connectionId and sasl token
+ * used in SASL negotiation. SecurityMessage has routines for converting
+ * it to and from a BufferMessage so that it can be sent by the 
ConnectionManager
+ * and easily consumed by users when received.
+ * The api was modeled after BlockMessage.
+ *
+ * The connectionId is the connectionId of the client side. Since
+ * message passing is asynchronous and its possible for the server side 
(receiving)
+ * to get multiple different types of messages on the same connection the 
connectionId
+ * is used to know which connection the security message is intended for.
+ *
+ * For instance, lets say we are node_0. We need to send data to node_1. The 
node_0 side
+ * is acting as a client and connecting to node_1. SASL negotiation has to 
occur
+ * between node_0 and node_1 before node_1 trusts node_0 so node_0 sends a 
security message.
+ * node_1 receives the message from node_0 but before it can process it and 
send a response,
+ * some thread on node_1 decides it needs to send data to node_0 so it 
connects to node_0
+ * and sends a security message of its own to authenticate as a client. Now 
node_0 gets
+ * the message and it needs to decide if this message is in response to it 
being a client
+ * (from the first send) or if its just node_1 trying to connect to it to send 
data.  This
+ * is where the connectionId field is used. node_0 can lookup the connectionId 
to see if
+ * it is in response to it being a client or if its in response to someone 
sending other data.
+ *
+ * The format of a SecurityMessage as its sent is:
+ *   - Length of the ConnectionId
+ *   - ConnectionId
+ *   - Length of the token
+ *   - Token
+ */
+private[nio] class SecurityMessage extends Logging {
+
+  private var connectionId: String = null
+  private var token: Array[Byte] = null
+
+  /** Sets both fields directly; a null `byteArr` is stored as an empty token. */
+  def set(byteArr: Array[Byte], newconnectionId: String) {
+    token = if (byteArr == null) new Array[Byte](0) else byteArr
+    connectionId = newconnectionId
+  }
+
+  /**
+   * Read the given buffer and set the members of this class.
+   */
+  def set(buffer: ByteBuffer) {
+    // Connection id: a character count followed by that many 2-byte chars.
+    val idLength = buffer.getInt()
+    val idBuilder = new StringBuilder(idLength)
+    (1 to idLength).foreach { _ => idBuilder += buffer.getChar() }
+    connectionId = idBuilder.toString()
+
+    // Token: a byte count followed by the raw bytes.
+    val tokenLength = buffer.getInt()
+    token = new Array[Byte](tokenLength)
+    if (tokenLength > 0) {
+      buffer.get(token, 0, tokenLength)
+    }
+  }
+
+  /** Parses the first buffer of `bufferMsg` after calling `clear()` on it. */
+  def set(bufferMsg: BufferMessage) {
+    val firstBuffer = bufferMsg.buffers(0)
+    firstBuffer.clear()
+    set(firstBuffer)
+  }
+
+  def getConnectionId: String = connectionId
+
+  def getToken: Array[Byte] = token
+
+  /**
+   * Create a BufferMessage that can be sent by the ConnectionManager containing
+   * the security information from this class.
+   * @return BufferMessage
+   */
+  def toBufferMessage: BufferMessage = {
+    // 4 bytes for the length of the connectionId
+    // connectionId is of type char so multiply the length by 2 to get number of bytes
+    // 4 bytes for the length of token
+    // token is a byte buffer so just take the length
+    val idChars = connectionId.length()
+    val payload = ByteBuffer.allocate(4 + idChars * 2 + 4 + token.length)
+    payload.putInt(idChars)
+    connectionId.foreach { c => payload.putChar(c) }
+    payload.putInt(token.length)
+    if (token.length > 0) {
+      payload.put(token)
+    }
+    payload.flip()
+
+    val message = Message.createBufferMessage(ArrayBuffer[ByteBuffer](payload))
+    logDebug("message total size is : " + message.size)
+    message.isSecurityNeg = true
+    message
+  }
+
+  override def toString: String = {
+    s"SecurityMessage [connId= $connectionId, Token = $token]"
+  }
+}
+
+private[nio] object SecurityMessage {
+
+  /**
+   * Convert the given BufferMessage to a SecurityMessage by parsing the contents
+   * of the BufferMessage and populating the SecurityMessage fields.
+   * @param bufferMessage is a BufferMessage that was received
+   * @return new SecurityMessage
+   */
+  def fromBufferMessage(bufferMessage: BufferMessage): SecurityMessage = {
+    val parsed = new SecurityMessage()
+    parsed.set(bufferMessage)
+    parsed
+  }
+
+  /**
+   * Create a SecurityMessage to send from a given saslResponse.
+   * @param response is the response to a challenge from the SaslClient or SaslServer
+   * @param connectionId the client connectionId we are negotiating authentication for
+   * @return a new SecurityMessage
+   */
+  def fromResponse(response: Array[Byte], connectionId: String): SecurityMessage = {
+    val outgoing = new SecurityMessage()
+    outgoing.set(response, connectionId)
+    outgoing
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/08ce1888/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
----------------------------------------------------------------------
diff --git 
a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala 
b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
index 87ef9bb..d6386f8 100644
--- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
@@ -27,9 +27,9 @@ import com.twitter.chill.{AllScalaRegistrar, 
EmptyScalaKryoInstantiator}
 
 import org.apache.spark._
 import org.apache.spark.broadcast.HttpBroadcast
+import org.apache.spark.network.nio.{PutBlock, GotBlock, GetBlock}
 import org.apache.spark.scheduler.MapStatus
 import org.apache.spark.storage._
-import org.apache.spark.storage.{GetBlock, GotBlock, PutBlock}
 import org.apache.spark.util.BoundedPriorityQueue
 import org.apache.spark.util.collection.CompactBuffer
 

http://git-wip-us.apache.org/repos/asf/spark/blob/08ce1888/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala
----------------------------------------------------------------------
diff --git 
a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala 
b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala
index 96faccc..439981d 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala
@@ -26,6 +26,7 @@ import scala.collection.JavaConversions._
 
 import org.apache.spark.{SparkEnv, SparkConf, Logging}
 import org.apache.spark.executor.ShuffleWriteMetrics
+import org.apache.spark.network.{FileSegmentManagedBuffer, ManagedBuffer}
 import org.apache.spark.serializer.Serializer
 import org.apache.spark.shuffle.FileShuffleBlockManager.ShuffleFileGroup
 import org.apache.spark.storage._
@@ -166,34 +167,30 @@ class FileShuffleBlockManager(conf: SparkConf)
     }
   }
 
-  /**
-   * Returns the physical file segment in which the given BlockId is located.
-   */
-  private def getBlockLocation(id: ShuffleBlockId): FileSegment = {
+  override def getBytes(blockId: ShuffleBlockId): Option[ByteBuffer] = {
+    val segment = getBlockData(blockId)
+    Some(segment.nioByteBuffer())
+  }
+
+  override def getBlockData(blockId: ShuffleBlockId): ManagedBuffer = {
     if (consolidateShuffleFiles) {
       // Search all file groups associated with this shuffle.
-      val shuffleState = shuffleStates(id.shuffleId)
+      val shuffleState = shuffleStates(blockId.shuffleId)
       val iter = shuffleState.allFileGroups.iterator
       while (iter.hasNext) {
-        val segment = iter.next.getFileSegmentFor(id.mapId, id.reduceId)
-        if (segment.isDefined) { return segment.get }
+        val segmentOpt = iter.next.getFileSegmentFor(blockId.mapId, 
blockId.reduceId)
+        if (segmentOpt.isDefined) {
+          val segment = segmentOpt.get
+          return new FileSegmentManagedBuffer(segment.file, segment.offset, 
segment.length)
+        }
       }
-      throw new IllegalStateException("Failed to find shuffle block: " + id)
+      throw new IllegalStateException("Failed to find shuffle block: " + 
blockId)
     } else {
-      val file = blockManager.diskBlockManager.getFile(id)
-      new FileSegment(file, 0, file.length())
+      val file = blockManager.diskBlockManager.getFile(blockId)
+      new FileSegmentManagedBuffer(file, 0, file.length)
     }
   }
 
-  override def getBytes(blockId: ShuffleBlockId): Option[ByteBuffer] = {
-    val segment = getBlockLocation(blockId)
-    blockManager.diskStore.getBytes(segment)
-  }
-
-  override def getBlockData(blockId: ShuffleBlockId): Either[FileSegment, 
ByteBuffer] = {
-    Left(getBlockLocation(blockId.asInstanceOf[ShuffleBlockId]))
-  }
-
   /** Remove all the blocks / files and metadata related to a particular 
shuffle. */
   def removeShuffle(shuffleId: ShuffleId): Boolean = {
     // Do not change the ordering of this, if shuffleStates should be removed 
only

http://git-wip-us.apache.org/repos/asf/spark/blob/08ce1888/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala
----------------------------------------------------------------------
diff --git 
a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala 
b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala
index 8bb9efc..4ab3433 100644
--- 
a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala
+++ 
b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala
@@ -21,6 +21,7 @@ import java.io._
 import java.nio.ByteBuffer
 
 import org.apache.spark.SparkEnv
+import org.apache.spark.network.{ManagedBuffer, FileSegmentManagedBuffer}
 import org.apache.spark.storage._
 
 /**
@@ -89,10 +90,11 @@ class IndexShuffleBlockManager extends ShuffleBlockManager {
     }
   }
 
-  /**
-   * Get the location of a block in a map output file. Uses the index file we 
create for it.
-   * */
-  private def getBlockLocation(blockId: ShuffleBlockId): FileSegment = {
+  override def getBytes(blockId: ShuffleBlockId): Option[ByteBuffer] = {
+    Some(getBlockData(blockId).nioByteBuffer())
+  }
+
+  override def getBlockData(blockId: ShuffleBlockId): ManagedBuffer = {
     // The block is actually going to be a range of a single map output file 
for this map, so
     // find out the consolidated file, then the offset within that from our 
index
     val indexFile = getIndexFile(blockId.shuffleId, blockId.mapId)
@@ -102,20 +104,14 @@ class IndexShuffleBlockManager extends 
ShuffleBlockManager {
       in.skip(blockId.reduceId * 8)
       val offset = in.readLong()
       val nextOffset = in.readLong()
-      new FileSegment(getDataFile(blockId.shuffleId, blockId.mapId), offset, 
nextOffset - offset)
+      new FileSegmentManagedBuffer(
+        getDataFile(blockId.shuffleId, blockId.mapId),
+        offset,
+        nextOffset - offset)
     } finally {
       in.close()
     }
   }
 
-  override def getBytes(blockId: ShuffleBlockId): Option[ByteBuffer] = {
-    val segment = getBlockLocation(blockId)
-    blockManager.diskStore.getBytes(segment)
-  }
-
-  override def getBlockData(blockId: ShuffleBlockId): Either[FileSegment, 
ByteBuffer] = {
-    Left(getBlockLocation(blockId.asInstanceOf[ShuffleBlockId]))
-  }
-
   override def stop() = {}
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/08ce1888/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockManager.scala
----------------------------------------------------------------------
diff --git 
a/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockManager.scala 
b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockManager.scala
index 4240580..63863cc 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockManager.scala
@@ -19,7 +19,8 @@ package org.apache.spark.shuffle
 
 import java.nio.ByteBuffer
 
-import org.apache.spark.storage.{FileSegment, ShuffleBlockId}
+import org.apache.spark.network.ManagedBuffer
+import org.apache.spark.storage.ShuffleBlockId
 
 private[spark]
 trait ShuffleBlockManager {
@@ -31,8 +32,7 @@ trait ShuffleBlockManager {
    */
   def getBytes(blockId: ShuffleBlockId): Option[ByteBuffer]
 
-  def getBlockData(blockId: ShuffleBlockId): Either[FileSegment, ByteBuffer]
+  def getBlockData(blockId: ShuffleBlockId): ManagedBuffer
 
   def stop(): Unit
 }
-

http://git-wip-us.apache.org/repos/asf/spark/blob/08ce1888/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala
----------------------------------------------------------------------
diff --git 
a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala
 
b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala
index 12b4756..6cf9305 100644
--- 
a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala
+++ 
b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala
@@ -21,10 +21,9 @@ import scala.collection.mutable.ArrayBuffer
 import scala.collection.mutable.HashMap
 
 import org.apache.spark._
-import org.apache.spark.executor.ShuffleReadMetrics
 import org.apache.spark.serializer.Serializer
 import org.apache.spark.shuffle.FetchFailedException
-import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockId}
+import org.apache.spark.storage.{BlockId, BlockManagerId, 
ShuffleBlockFetcherIterator, ShuffleBlockId}
 import org.apache.spark.util.CompletionIterator
 
 private[hash] object BlockStoreShuffleFetcher extends Logging {
@@ -32,8 +31,7 @@ private[hash] object BlockStoreShuffleFetcher extends Logging 
{
       shuffleId: Int,
       reduceId: Int,
       context: TaskContext,
-      serializer: Serializer,
-      shuffleMetrics: ShuffleReadMetrics)
+      serializer: Serializer)
     : Iterator[T] =
   {
     logDebug("Fetching outputs for shuffle %d, reduce %d".format(shuffleId, 
reduceId))
@@ -74,7 +72,13 @@ private[hash] object BlockStoreShuffleFetcher extends 
Logging {
       }
     }
 
-    val blockFetcherItr = blockManager.getMultiple(blocksByAddress, 
serializer, shuffleMetrics)
+    val blockFetcherItr = new ShuffleBlockFetcherIterator(
+      context,
+      SparkEnv.get.blockTransferService,
+      blockManager,
+      blocksByAddress,
+      serializer,
+      SparkEnv.get.conf.getLong("spark.reducer.maxMbInFlight", 48) * 1024 * 
1024)
     val itr = blockFetcherItr.flatMap(unpackBlock)
 
     val completionIter = CompletionIterator[T, Iterator[T]](itr, {

http://git-wip-us.apache.org/repos/asf/spark/blob/08ce1888/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
----------------------------------------------------------------------
diff --git 
a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala 
b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
index 7bed97a..88a5f1e 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
@@ -36,10 +36,8 @@ private[spark] class HashShuffleReader[K, C](
 
   /** Read the combined key-values for this reduce task */
   override def read(): Iterator[Product2[K, C]] = {
-    val readMetrics = 
context.taskMetrics.createShuffleReadMetricsForDependency()
     val ser = Serializer.getSerializer(dep.serializer)
-    val iter = BlockStoreShuffleFetcher.fetch(handle.shuffleId, 
startPartition, context, ser,
-      readMetrics)
+    val iter = BlockStoreShuffleFetcher.fetch(handle.shuffleId, 
startPartition, context, ser)
 
     val aggregatedIter: Iterator[Product2[K, C]] = if 
(dep.aggregator.isDefined) {
       if (dep.mapSideCombine) {

http://git-wip-us.apache.org/repos/asf/spark/blob/08ce1888/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala
----------------------------------------------------------------------
diff --git 
a/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala 
b/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala
deleted file mode 100644
index e35b7fe..0000000
--- a/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala
+++ /dev/null
@@ -1,254 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.storage
-
-import java.util.concurrent.LinkedBlockingQueue
-import org.apache.spark.network.netty.client.{BlockClientListener, 
LazyInitIterator, ReferenceCountedBuffer}
-
-import scala.collection.mutable.ArrayBuffer
-import scala.collection.mutable.HashSet
-import scala.collection.mutable.Queue
-import scala.util.{Failure, Success}
-
-import org.apache.spark.{Logging, SparkException}
-import org.apache.spark.executor.ShuffleReadMetrics
-import org.apache.spark.network.BufferMessage
-import org.apache.spark.network.ConnectionManagerId
-import org.apache.spark.serializer.Serializer
-import org.apache.spark.util.Utils
-
-/**
- * A block fetcher iterator interface for fetching shuffle blocks.
- */
-private[storage]
-trait BlockFetcherIterator extends Iterator[(BlockId, Option[Iterator[Any]])] 
with Logging {
-  def initialize()
-}
-
-
-private[storage]
-object BlockFetcherIterator {
-
-  /**
-   * A request to fetch blocks from a remote BlockManager.
-   * @param address remote BlockManager to fetch from.
-   * @param blocks Sequence of tuple, where the first element is the block id,
-   *               and the second element is the estimated size, used to 
calculate bytesInFlight.
-   */
-  class FetchRequest(val address: BlockManagerId, val blocks: Seq[(BlockId, 
Long)]) {
-    val size = blocks.map(_._2).sum
-  }
-
-  /**
-   * Result of a fetch from a remote block. A failure is represented as size 
== -1.
-   * @param blockId block id
-   * @param size estimated size of the block, used to calculate bytesInFlight.
-   *             Note that this is NOT the exact bytes.
-   * @param deserialize closure to return the result in the form of an 
Iterator.
-   */
-  class FetchResult(val blockId: BlockId, val size: Long, val deserialize: () 
=> Iterator[Any]) {
-    def failed: Boolean = size == -1
-  }
-
-  // TODO: Refactor this whole thing to make code more reusable.
-  class BasicBlockFetcherIterator(
-      private val blockManager: BlockManager,
-      val blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])],
-      serializer: Serializer,
-      readMetrics: ShuffleReadMetrics)
-    extends BlockFetcherIterator {
-
-    import blockManager._
-
-    if (blocksByAddress == null) {
-      throw new IllegalArgumentException("BlocksByAddress is null")
-    }
-
-    // Total number blocks fetched (local + remote). Also number of 
FetchResults expected
-    protected var _numBlocksToFetch = 0
-
-    protected var startTime = System.currentTimeMillis
-
-    // BlockIds for local blocks that need to be fetched. Excludes zero-sized 
blocks
-    protected val localBlocksToFetch = new ArrayBuffer[BlockId]()
-
-    // BlockIds for remote blocks that need to be fetched. Excludes zero-sized 
blocks
-    protected val remoteBlocksToFetch = new HashSet[BlockId]()
-
-    // A queue to hold our results.
-    protected val results = new LinkedBlockingQueue[FetchResult]
-
-    // Queue of fetch requests to issue; we'll pull requests off this 
gradually to make sure that
-    // the number of bytes in flight is limited to maxBytesInFlight
-    protected val fetchRequests = new Queue[FetchRequest]
-
-    // Current bytes in flight from our requests
-    protected var bytesInFlight = 0L
-
-    protected def sendRequest(req: FetchRequest) {
-      logDebug("Sending request for %d blocks (%s) from %s".format(
-        req.blocks.size, Utils.bytesToString(req.size), req.address.hostPort))
-      val cmId = new ConnectionManagerId(req.address.host, req.address.port)
-      val blockMessageArray = new BlockMessageArray(req.blocks.map {
-        case (blockId, size) => BlockMessage.fromGetBlock(GetBlock(blockId))
-      })
-      bytesInFlight += req.size
-      val sizeMap = req.blocks.toMap  // so we can look up the size of each 
blockID
-      val future = connectionManager.sendMessageReliably(cmId, 
blockMessageArray.toBufferMessage)
-      future.onComplete {
-        case Success(message) => {
-          val bufferMessage = message.asInstanceOf[BufferMessage]
-          val blockMessageArray = 
BlockMessageArray.fromBufferMessage(bufferMessage)
-          for (blockMessage <- blockMessageArray) {
-            if (blockMessage.getType != BlockMessage.TYPE_GOT_BLOCK) {
-              throw new SparkException(
-                "Unexpected message " + blockMessage.getType + " received from 
" + cmId)
-            }
-            val blockId = blockMessage.getId
-            val networkSize = blockMessage.getData.limit()
-            results.put(new FetchResult(blockId, sizeMap(blockId),
-              () => dataDeserialize(blockId, blockMessage.getData, 
serializer)))
-            // TODO: NettyBlockFetcherIterator has some race conditions where 
multiple threads can
-            // be incrementing bytes read at the same time (SPARK-2625).
-            readMetrics.remoteBytesRead += networkSize
-            readMetrics.remoteBlocksFetched += 1
-            logDebug("Got remote block " + blockId + " after " + 
Utils.getUsedTimeMs(startTime))
-          }
-        }
-        case Failure(exception) => {
-          logError("Could not get block(s) from " + cmId, exception)
-          for ((blockId, size) <- req.blocks) {
-            results.put(new FetchResult(blockId, -1, null))
-          }
-        }
-      }
-    }
-
-    protected def splitLocalRemoteBlocks(): ArrayBuffer[FetchRequest] = {
-      // Make remote requests at most maxBytesInFlight / 5 in length; the 
reason to keep them
-      // smaller than maxBytesInFlight is to allow multiple, parallel fetches 
from up to 5
-      // nodes, rather than blocking on reading output from one node.
-      val targetRequestSize = math.max(maxBytesInFlight / 5, 1L)
-      logInfo("maxBytesInFlight: " + maxBytesInFlight + ", targetRequestSize: 
" + targetRequestSize)
-
-      // Split local and remote blocks. Remote blocks are further split into 
FetchRequests of size
-      // at most maxBytesInFlight in order to limit the amount of data in 
flight.
-      val remoteRequests = new ArrayBuffer[FetchRequest]
-      var totalBlocks = 0
-      for ((address, blockInfos) <- blocksByAddress) {
-        totalBlocks += blockInfos.size
-        if (address == blockManagerId) {
-          // Filter out zero-sized blocks
-          localBlocksToFetch ++= blockInfos.filter(_._2 != 0).map(_._1)
-          _numBlocksToFetch += localBlocksToFetch.size
-        } else {
-          val iterator = blockInfos.iterator
-          var curRequestSize = 0L
-          var curBlocks = new ArrayBuffer[(BlockId, Long)]
-          while (iterator.hasNext) {
-            val (blockId, size) = iterator.next()
-            // Skip empty blocks
-            if (size > 0) {
-              curBlocks += ((blockId, size))
-              remoteBlocksToFetch += blockId
-              _numBlocksToFetch += 1
-              curRequestSize += size
-            } else if (size < 0) {
-              throw new BlockException(blockId, "Negative block size " + size)
-            }
-            if (curRequestSize >= targetRequestSize) {
-              // Add this FetchRequest
-              remoteRequests += new FetchRequest(address, curBlocks)
-              curBlocks = new ArrayBuffer[(BlockId, Long)]
-              logDebug(s"Creating fetch request of $curRequestSize at 
$address")
-              curRequestSize = 0
-            }
-          }
-          // Add in the final request
-          if (!curBlocks.isEmpty) {
-            remoteRequests += new FetchRequest(address, curBlocks)
-          }
-        }
-      }
-      logInfo("Getting " + _numBlocksToFetch + " non-empty blocks out of " +
-        totalBlocks + " blocks")
-      remoteRequests
-    }
-
-    protected def getLocalBlocks() {
-      // Get the local blocks while remote blocks are being fetched. Note that 
it's okay to do
-      // these all at once because they will just memory-map some files, so 
they won't consume
-      // any memory that might exceed our maxBytesInFlight
-      for (id <- localBlocksToFetch) {
-        try {
-          readMetrics.localBlocksFetched += 1
-          results.put(new FetchResult(id, 0, () => getLocalShuffleFromDisk(id, 
serializer).get))
-          logDebug("Got local block " + id)
-        } catch {
-          case e: Exception => {
-            logError(s"Error occurred while fetching local blocks", e)
-            results.put(new FetchResult(id, -1, null))
-            return
-          }
-        }
-      }
-    }
-
-    override def initialize() {
-      // Split local and remote blocks.
-      val remoteRequests = splitLocalRemoteBlocks()
-      // Add the remote requests into our queue in a random order
-      fetchRequests ++= Utils.randomize(remoteRequests)
-
-      // Send out initial requests for blocks, up to our maxBytesInFlight
-      while (!fetchRequests.isEmpty &&
-        (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= 
maxBytesInFlight)) {
-        sendRequest(fetchRequests.dequeue())
-      }
-
-      val numFetches = remoteRequests.size - fetchRequests.size
-      logInfo("Started " + numFetches + " remote fetches in" + 
Utils.getUsedTimeMs(startTime))
-
-      // Get Local Blocks
-      startTime = System.currentTimeMillis
-      getLocalBlocks()
-      logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime) + " ms")
-    }
-
-    // Implementing the Iterator methods with an iterator that reads fetched 
blocks off the queue
-    // as they arrive.
-    @volatile protected var resultsGotten = 0
-
-    override def hasNext: Boolean = resultsGotten < _numBlocksToFetch
-
-    override def next(): (BlockId, Option[Iterator[Any]]) = {
-      resultsGotten += 1
-      val startFetchWait = System.currentTimeMillis()
-      val result = results.take()
-      val stopFetchWait = System.currentTimeMillis()
-      readMetrics.fetchWaitTime += (stopFetchWait - startFetchWait)
-      if (! result.failed) bytesInFlight -= result.size
-      while (!fetchRequests.isEmpty &&
-        (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= 
maxBytesInFlight)) {
-        sendRequest(fetchRequests.dequeue())
-      }
-      (result.blockId, if (result.failed) None else Some(result.deserialize()))
-    }
-  }
-  // End of BasicBlockFetcherIterator
-}


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to