// Extracted from an Apache Spark commit diff (new file, blob 08ce1888):
//   http://git-wip-us.apache.org/repos/asf/spark/blob/08ce1888/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.network.nio

import java.io.IOException
import java.net._
import java.nio._
import java.nio.channels._
import java.nio.channels.spi._
import java.util.concurrent.atomic.AtomicInteger
import java.util.concurrent.{LinkedBlockingDeque, ThreadPoolExecutor, TimeUnit}
import java.util.{Timer, TimerTask}

import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, SynchronizedMap, SynchronizedQueue}
import scala.concurrent.duration._
import scala.concurrent.{Await, ExecutionContext, Future, Promise}
import scala.language.postfixOps

import org.apache.spark._
import org.apache.spark.util.{SystemClock, Utils}


/**
 * NIO based message transport: owns a selector thread, a listening server channel and the
 * thread pools that service connect / read / write / message-handling work.
 *
 * @param port port the server channel is bound to (via `Utils.startServiceOnPort`)
 * @param conf source of the `spark.core.connection.*` settings read below
 * @param securityManager decides whether SASL authentication is performed on connections
 * @param name service name passed to `Utils.startServiceOnPort`
 */
private[nio] class ConnectionManager(
    port: Int,
    conf: SparkConf,
    securityManager: SecurityManager,
    name: String = "Connection manager")
  extends Logging {

  /**
   * Used by sendMessageReliably to track messages being sent.
   * @param message the message that was sent
   * @param connectionManagerId the connection manager that sent this message
   * @param completionHandler callback that's invoked when the send has completed or failed
   */
  class MessageStatus(
      val message: Message,
      val connectionManagerId: ConnectionManagerId,
      completionHandler: MessageStatus => Unit) {

    /** This is non-None if message has been ack'd */
    var ackMessage: Option[Message] = None

    /** Records the ack (None signals failure) and invokes the completion handler. */
    def markDone(ackMessage: Option[Message]): Unit = {
      this.ackMessage = ackMessage
      completionHandler(this)
    }
  }

  private val selector = SelectorProvider.provider.openSelector()
  private val ackTimeoutMonitor = new Timer("AckTimeoutMonitor", true)

  // default to 30 second timeout waiting for authentication
  private val authTimeout = conf.getInt("spark.core.connection.auth.wait.timeout", 30)
  // seconds to wait for an ack before sendMessageReliably fails the returned future
  private val ackTimeout = conf.getInt("spark.core.connection.ack.wait.timeout", 60)

  // NOTE(review): all three pools below use an unbounded LinkedBlockingDeque. Per the
  // java.util.concurrent.ThreadPoolExecutor contract, threads beyond the core size are only
  // created when the queue rejects a task, so the "*.threads.max" settings can never take
  // effect with this queue. Left unchanged here to preserve behavior.
  private val handleMessageExecutor = new ThreadPoolExecutor(
    conf.getInt("spark.core.connection.handler.threads.min", 20),
    conf.getInt("spark.core.connection.handler.threads.max", 60),
    conf.getInt("spark.core.connection.handler.threads.keepalive", 60), TimeUnit.SECONDS,
    new LinkedBlockingDeque[Runnable](),
    Utils.namedThreadFactory("handle-message-executor"))

  private val handleReadWriteExecutor = new ThreadPoolExecutor(
    conf.getInt("spark.core.connection.io.threads.min", 4),
    conf.getInt("spark.core.connection.io.threads.max", 32),
    conf.getInt("spark.core.connection.io.threads.keepalive", 60), TimeUnit.SECONDS,
    new LinkedBlockingDeque[Runnable](),
    Utils.namedThreadFactory("handle-read-write-executor"))

  // Use a different, yet smaller, thread pool - infrequently used with very short lived tasks :
  // which should be executed asap
  private val handleConnectExecutor = new ThreadPoolExecutor(
    conf.getInt("spark.core.connection.connect.threads.min", 1),
    conf.getInt("spark.core.connection.connect.threads.max", 8),
    conf.getInt("spark.core.connection.connect.threads.keepalive", 60), TimeUnit.SECONDS,
    new LinkedBlockingDeque[Runnable](),
    Utils.namedThreadFactory("handle-connect-executor"))

  private val serverChannel = ServerSocketChannel.open()
  // used to track the SendingConnections waiting to do SASL negotiation
  private val connectionsAwaitingSasl = new HashMap[ConnectionId, SendingConnection]
    with SynchronizedMap[ConnectionId, SendingConnection]
  private val connectionsByKey =
    new HashMap[SelectionKey, Connection] with SynchronizedMap[SelectionKey, Connection]
  private val connectionsById = new HashMap[ConnectionManagerId, SendingConnection]
    with SynchronizedMap[ConnectionManagerId, SendingConnection]
  // Not a SynchronizedMap: every access in this class is wrapped in messageStatuses.synchronized
  private val messageStatuses = new HashMap[Int, MessageStatus]
  private val keyInterestChangeRequests = new SynchronizedQueue[(SelectionKey, Int)]
  private val registerRequests = new SynchronizedQueue[SendingConnection]

  implicit val futureExecContext = ExecutionContext.fromExecutor(
    Utils.newDaemonCachedThreadPool("Connection manager future execution context"))

  @volatile
  private var onReceiveCallback: (BufferMessage, ConnectionManagerId) => Option[Message] = null

  private val authEnabled = securityManager.isAuthenticationEnabled()

  serverChannel.configureBlocking(false)
  serverChannel.socket.setReuseAddress(true)
  serverChannel.socket.setReceiveBufferSize(256 * 1024)

  private def startService(port: Int): (ServerSocketChannel, Int) = {
    serverChannel.socket.bind(new InetSocketAddress(port))
    (serverChannel, serverChannel.socket.getLocalPort)
  }
  Utils.startServiceOnPort[ServerSocketChannel](port, startService, name)
  serverChannel.register(selector, SelectionKey.OP_ACCEPT)

  val id = new ConnectionManagerId(Utils.localHostName, serverChannel.socket.getLocalPort)
  logInfo("Bound socket to port " + serverChannel.socket.getLocalPort() + " with id = " + id)

  // used in combination with the ConnectionManagerId to create unique Connection ids
  // to be able to track asynchronous messages
  private val idCount: AtomicInteger = new AtomicInteger(1)

  private val selectorThread = new Thread("connection-manager-thread") {
    override def run() = ConnectionManager.this.run()
  }
  selectorThread.setDaemon(true)
  selectorThread.start()

  // Keys that currently have a write task queued or running; guarded by its own monitor.
  private val writeRunnableStarted: HashSet[SelectionKey] = new HashSet[SelectionKey]()

  /**
   * Schedules a write for the connection registered under `key`, unless one is already
   * queued or running for that key.
   */
  private def triggerWrite(key: SelectionKey): Unit = {
    val conn = connectionsByKey.getOrElse(key, null)
    if (conn == null) return

    val newlyStarted = writeRunnableStarted.synchronized {
      // So that we do not trigger more write events while processing this one.
      // The write method will re-register when done.
      if (conn.changeInterestForWrite()) conn.unregisterInterest()
      // add() returns false when the key was already present, i.e. a task is in flight
      writeRunnableStarted.add(key)
    }
    if (!newlyStarted) return

    handleReadWriteExecutor.execute(new Runnable {
      override def run(): Unit = {
        var register: Boolean = false
        try {
          register = conn.write()
        } finally {
          writeRunnableStarted.synchronized {
            writeRunnableStarted -= key
            val needReregister = register || conn.resetForceReregister()
            if (needReregister && conn.changeInterestForWrite()) {
              conn.registerInterest()
            }
          }
        }
      }
    })
  }

  // Keys that currently have a read task queued or running; guarded by its own monitor.
  private val readRunnableStarted: HashSet[SelectionKey] = new HashSet[SelectionKey]()

  /**
   * Schedules a read for the connection registered under `key`, unless one is already
   * queued or running for that key.
   */
  private def triggerRead(key: SelectionKey): Unit = {
    val conn = connectionsByKey.getOrElse(key, null)
    if (conn == null) return

    val newlyStarted = readRunnableStarted.synchronized {
      // So that we do not trigger more read events while processing this one.
      // The read method will re-register when done.
      if (conn.changeInterestForRead()) conn.unregisterInterest()
      readRunnableStarted.add(key)
    }
    if (!newlyStarted) return

    handleReadWriteExecutor.execute(new Runnable {
      override def run(): Unit = {
        var register: Boolean = false
        try {
          register = conn.read()
        } finally {
          readRunnableStarted.synchronized {
            readRunnableStarted -= key
            if (register && conn.changeInterestForRead()) {
              conn.registerInterest()
            }
          }
        }
      }
    })
  }

  /** Finishes an outgoing connect on the connect pool, retrying a few times first. */
  private def triggerConnect(key: SelectionKey): Unit = {
    val conn = connectionsByKey.getOrElse(key, null).asInstanceOf[SendingConnection]
    if (conn == null) return

    // prevent other events from being triggered
    // Since we are still trying to connect, we do not need to do the additional steps in
    // triggerWrite
    conn.changeConnectionKeyInterest(0)

    handleConnectExecutor.execute(new Runnable {
      override def run(): Unit = {

        var tries: Int = 10
        while (tries >= 0) {
          if (conn.finishConnect(false)) return
          // Sleep ?
          Thread.sleep(1)
          tries -= 1
        }

        // fallback to previous behavior : we should not really come here since this method was
        // triggered since channel became connectable : but at times, the first finishConnect need
        // not succeed : hence the loop to retry a few 'times'.
        conn.finishConnect(true)
      }
    })
  }

  // MUST be called within selector loop - else deadlock.
  private def triggerForceCloseByException(key: SelectionKey, e: Exception): Unit = {
    try {
      key.interestOps(0)
    } catch {
      // ignore exceptions
      case ignored: Exception => logDebug("Ignoring exception", ignored)
    }

    val conn = connectionsByKey.getOrElse(key, null)
    if (conn == null) return

    // Pushing to connect threadpool
    handleConnectExecutor.execute(new Runnable {
      override def run(): Unit = {
        try {
          // `e` is the exception handed to triggerForceCloseByException
          conn.callOnExceptionCallback(e)
        } catch {
          // ignore exceptions
          case ignored: Exception => logDebug("Ignoring exception", ignored)
        }
        try {
          conn.close()
        } catch {
          // ignore exceptions
          case ignored: Exception => logDebug("Ignoring exception", ignored)
        }
      }
    })
  }

  /**
   * Shared failure path for a selection key inside the selector loop: logs at INFO for a
   * cancelled key and at ERROR for anything else, then force-closes the connection.
   */
  private def handleKeyFailure(key: SelectionKey, e: Exception): Unit = {
    e match {
      case _: CancelledKeyException => logInfo("key already cancelled ? " + key, e)
      case _ => logError("Exception processing key " + key, e)
    }
    triggerForceCloseByException(key, e)
  }

  /** Renders an interest-ops bit mask for trace logging; a single space when no bit is set. */
  private def interestOpsToString(op: Int): String = {
    val opStrs = ArrayBuffer[String]()
    if ((op & SelectionKey.OP_READ) != 0) opStrs += "READ"
    if ((op & SelectionKey.OP_WRITE) != 0) opStrs += "WRITE"
    if ((op & SelectionKey.OP_CONNECT) != 0) opStrs += "CONNECT"
    if ((op & SelectionKey.OP_ACCEPT) != 0) opStrs += "ACCEPT"
    if (opStrs.nonEmpty) opStrs.mkString(" | ") else " "
  }

  /**
   * Body of the selector thread. Each iteration drains pending connection registrations and
   * interest-op changes, then selects and dispatches ready keys, until the thread is
   * interrupted.
   */
  def run(): Unit = {
    try {
      while (!selectorThread.isInterrupted) {
        while (!registerRequests.isEmpty) {
          val conn: SendingConnection = registerRequests.dequeue()
          addListeners(conn)
          conn.connect()
          addConnection(conn)
        }

        while (!keyInterestChangeRequests.isEmpty) {
          val (key, ops) = keyInterestChangeRequests.dequeue()

          try {
            if (key.isValid) {
              val connection = connectionsByKey.getOrElse(key, null)
              if (connection != null) {
                val lastOps = key.interestOps()
                key.interestOps(ops)

                // hot loop - prevent materialization of string if trace not enabled.
                if (isTraceEnabled()) {
                  logTrace("Changed key for connection to [" +
                    connection.getRemoteConnectionManagerId() + "] changed from [" +
                    interestOpsToString(lastOps) + "] to [" + interestOpsToString(ops) + "]")
                }
              }
            } else {
              logInfo("Key not valid ? " + key)
              throw new CancelledKeyException()
            }
          } catch {
            case e: Exception => handleKeyFailure(key, e)
          }
        }

        val selectedKeysCount =
          try {
            selector.select()
          } catch {
            // Explicitly only dealing with CancelledKeyException here since other exceptions
            // should be dealt with differently.
            case e: CancelledKeyException => {
              // Some keys within the selectors list are invalid/closed. clear them.
              val allKeys = selector.keys().iterator()

              while (allKeys.hasNext) {
                val key = allKeys.next()
                try {
                  if (!key.isValid) {
                    logInfo("Key not valid ? " + key)
                    throw new CancelledKeyException()
                  }
                } catch {
                  case inner: Exception => handleKeyFailure(key, inner)
                }
              }
              // nothing was selected in this round
              0
            }
          }

        if (selectedKeysCount == 0) {
          logDebug("Selector selected " + selectedKeysCount + " of " + selector.keys.size +
            " keys")
        }
        if (selectorThread.isInterrupted) {
          logInfo("Selector thread was interrupted!")
          return
        }

        if (0 != selectedKeysCount) {
          val selectedKeys = selector.selectedKeys().iterator()
          while (selectedKeys.hasNext) {
            val key = selectedKeys.next
            selectedKeys.remove()
            try {
              if (key.isValid) {
                // Only one kind of readiness is handled per key per round.
                if (key.isAcceptable) {
                  acceptConnection(key)
                } else if (key.isConnectable) {
                  triggerConnect(key)
                } else if (key.isReadable) {
                  triggerRead(key)
                } else if (key.isWritable) {
                  triggerWrite(key)
                }
              } else {
                logInfo("Key not valid ? " + key)
                throw new CancelledKeyException()
              }
            } catch {
              // weird, but we saw this happening - even though key.isValid was true,
              // key.isAcceptable would throw CancelledKeyException.
              case e: Exception => handleKeyFailure(key, e)
            }
          }
        }
      }
    } catch {
      case e: Exception => logError("Error in select loop", e)
    }
  }

  /** Accepts every pending inbound channel on `key` and registers a ReceivingConnection. */
  def acceptConnection(key: SelectionKey): Unit = {
    val acceptingChannel = key.channel.asInstanceOf[ServerSocketChannel]

    var newChannel = acceptingChannel.accept()

    // accept them all in a tight loop. non blocking accept with no processing, should be fine
    while (newChannel != null) {
      try {
        val newConnectionId = new ConnectionId(id, idCount.getAndIncrement.intValue)
        val newConnection = new ReceivingConnection(newChannel, selector, newConnectionId)
        newConnection.onReceive(receiveMessage)
        addListeners(newConnection)
        addConnection(newConnection)
        logInfo("Accepted connection from [" + newConnection.remoteAddress + "]")
      } catch {
        // might happen in case of issues with registering with selector
        case e: Exception => logError("Error in accept loop", e)
      }

      newChannel = acceptingChannel.accept()
    }
  }

  /** Wires this manager's interest-change, error and close handlers into `connection`. */
  private def addListeners(connection: Connection): Unit = {
    connection.onKeyInterestChange(changeConnectionKeyInterest)
    connection.onException(handleConnectionError)
    connection.onClose(removeConnection)
  }

  def addConnection(connection: Connection): Unit = {
    connectionsByKey += ((connection.key, connection))
  }

  /**
   * Fails (markDone(None)) and forgets every tracked message addressed to `remoteId`.
   */
  private def failPendingMessages(remoteId: ConnectionManagerId): Unit = {
    messageStatuses.synchronized {
      messageStatuses.values.filter(_.connectionManagerId == remoteId).foreach { status =>
        logInfo("Notifying " + status)
        status.markDone(None)
      }

      messageStatuses.retain((i, status) => status.connectionManagerId != remoteId)
    }
  }

  /**
   * Unregisters `connection`. For a sending connection the pending messages to its remote
   * are failed; for a receiving connection the matching sending connection (if any) is
   * closed as well and its pending messages are failed. Always wakes the selector.
   */
  def removeConnection(connection: Connection): Unit = {
    connectionsByKey -= connection.key

    try {
      connection match {
        case sendingConnection: SendingConnection =>
          val sendingConnectionManagerId = sendingConnection.getRemoteConnectionManagerId()
          logInfo("Removing SendingConnection to " + sendingConnectionManagerId)

          connectionsById -= sendingConnectionManagerId
          connectionsAwaitingSasl -= connection.connectionId

          failPendingMessages(sendingConnectionManagerId)

        case receivingConnection: ReceivingConnection =>
          val remoteConnectionManagerId = receivingConnection.getRemoteConnectionManagerId()
          logInfo("Removing ReceivingConnection to " + remoteConnectionManagerId)

          val sendingConnectionOpt = connectionsById.get(remoteConnectionManagerId)
          if (!sendingConnectionOpt.isDefined) {
            logError(s"Corresponding SendingConnection to ${remoteConnectionManagerId} not found")
            // the finally block below still wakes the selector
            return
          }

          val sendingConnection = sendingConnectionOpt.get
          connectionsById -= remoteConnectionManagerId
          sendingConnection.close()

          val sendingConnectionManagerId = sendingConnection.getRemoteConnectionManagerId()

          assert(sendingConnectionManagerId == remoteConnectionManagerId)

          failPendingMessages(sendingConnectionManagerId)

        case _ => logError("Unsupported type of connection.")
      }
    } finally {
      // So that the selection keys can be removed.
      wakeupSelector()
    }
  }

  def handleConnectionError(connection: Connection, e: Exception): Unit = {
    logInfo("Handling connection error on connection to " +
      connection.getRemoteConnectionManagerId())
    removeConnection(connection)
  }

  /** Queues an interest-ops change to be applied by the selector thread. */
  def changeConnectionKeyInterest(connection: Connection, ops: Int): Unit = {
    keyInterestChangeRequests += ((connection.key, ops))
    // so that registrations happen !
    wakeupSelector()
  }

  /** Hands a received message to the message-handler pool. */
  def receiveMessage(connection: Connection, message: Message): Unit = {
    val connectionManagerId = ConnectionManagerId.fromSocketAddress(message.senderAddress)
    logDebug("Received [" + message + "] from [" + connectionManagerId + "]")
    val runnable = new Runnable() {
      val creationTime = System.currentTimeMillis
      def run(): Unit = {
        logDebug("Handler thread delay is " + (System.currentTimeMillis - creationTime) + " ms")
        handleMessage(connectionManagerId, message, connection)
        logDebug("Handling delay is " + (System.currentTimeMillis - creationTime) + " ms")
      }
    }
    handleMessageExecutor.execute(runnable)
    /* handleMessage(connection, message) */
  }

  /**
   * Client side of the SASL exchange: either notices that negotiation is complete and wakes
   * the senders waiting on the connection, or answers the server's challenge.
   */
  private def handleClientAuthentication(
      waitingConn: SendingConnection,
      securityMsg: SecurityMessage,
      connectionId : ConnectionId): Unit = {

    // Stops tracking the connection and wakes threads blocked in sendMessage.
    def finishNegotiation(logMessage: String): Unit = {
      logDebug(logMessage + waitingConn.connectionId)
      connectionsAwaitingSasl -= waitingConn.connectionId
      waitingConn.getAuthenticated().synchronized {
        waitingConn.getAuthenticated().notifyAll()
      }
    }

    if (waitingConn.isSaslComplete()) {
      finishNegotiation("Client sasl completed for id: ")
    } else {
      try {
        val replyToken: Array[Byte] = waitingConn.sparkSaslClient.saslResponse(securityMsg.getToken)
        if (waitingConn.isSaslComplete()) {
          finishNegotiation("Client sasl completed after evaluate for id: ")
        } else {
          val securityMsgResp = SecurityMessage.fromResponse(replyToken,
            securityMsg.getConnectionId.toString)
          val message = securityMsgResp.toBufferMessage
          if (message == null) throw new IOException("Error creating security message")
          sendSecurityMessage(waitingConn.getRemoteConnectionManagerId(), message)
        }
      } catch {
        case e: Exception => {
          logError("Error handling sasl client authentication", e)
          waitingConn.close()
          throw new IOException("Error evaluating sasl response: ", e)
        }
      }
    }
  }

  /**
   * Server side of the SASL exchange: lazily creates the SASL server for the connection,
   * evaluates the client's token and sends back the reply token, if any. The connection is
   * closed on any failure.
   */
  private def handleServerAuthentication(
      connection: Connection,
      securityMsg: SecurityMessage,
      connectionId: ConnectionId): Unit = {
    if (!connection.isSaslComplete()) {
      logDebug("saslContext not established")
      try {
        connection.synchronized {
          if (connection.sparkSaslServer == null) {
            logDebug("Creating sasl Server")
            connection.sparkSaslServer = new SparkSaslServer(securityManager)
          }
        }
        val replyToken: Array[Byte] = connection.sparkSaslServer.response(securityMsg.getToken)
        if (connection.isSaslComplete()) {
          logDebug("Server sasl completed: " + connection.connectionId)
        } else {
          logDebug("Server sasl not completed: " + connection.connectionId)
        }
        if (replyToken != null) {
          val securityMsgResp = SecurityMessage.fromResponse(replyToken,
            securityMsg.getConnectionId)
          val message = securityMsgResp.toBufferMessage
          if (message == null) throw new Exception("Error creating security Message")
          sendSecurityMessage(connection.getRemoteConnectionManagerId(), message)
        }
      } catch {
        case e: Exception => {
          // pass the throwable so the stack trace is not lost
          logError("Error in server auth negotiation: " + e, e)
          // It would probably be better to send an error message telling other side auth failed
          // but for now just close
          connection.close()
        }
      }
    } else {
      logDebug("connection already established for this connection id: " + connection.connectionId)
    }
  }


  /**
   * @return true when `bufferMessage` was consumed here (a security negotiation message, or
   *         a regular message dropped because the connection is not authenticated yet);
   *         false when the caller should process it normally
   */
  private def handleAuthentication(conn: Connection, bufferMessage: BufferMessage): Boolean = {
    if (bufferMessage.isSecurityNeg) {
      logDebug("This is security neg message")

      // parse as SecurityMessage
      val securityMsg = SecurityMessage.fromBufferMessage(bufferMessage)
      val connectionId = ConnectionId.createConnectionIdFromString(securityMsg.getConnectionId)

      connectionsAwaitingSasl.get(connectionId) match {
        case Some(waitingConn) => {
          // Client - this must be in response to us doing Send
          logDebug("Client handleAuth for id: " + waitingConn.connectionId)
          handleClientAuthentication(waitingConn, securityMsg, connectionId)
        }
        case None => {
          // Server - someone sent us something and we haven't authenticated yet
          logDebug("Server handleAuth for id: " + connectionId)
          handleServerAuthentication(conn, securityMsg, connectionId)
        }
      }
      true
    } else if (!conn.isSaslComplete()) {
      // We could handle this better and tell the client we need to do authentication
      // negotiation, but for now just ignore them.
      logError("message sent that is not security negotiation message on connection " +
        "not authenticated yet, ignoring it!!")
      true
    } else {
      false
    }
  }

  /**
   * Processes one received message on a handler thread: security negotiation first (when
   * authentication is enabled), then either completes the pending send the message acks, or
   * invokes the receive callback and sends an ack back.
   */
  private def handleMessage(
      connectionManagerId: ConnectionManagerId,
      message: Message,
      connection: Connection): Unit = {
    logDebug("Handling [" + message + "] from [" + connectionManagerId + "]")
    message match {
      case bufferMessage: BufferMessage => {
        if (authEnabled) {
          val res = handleAuthentication(connection, bufferMessage)
          if (res) {
            // message was security negotiation so skip the rest
            logDebug("After handleAuth result was true, returning")
            return
          }
        }
        if (bufferMessage.hasAckId()) {
          messageStatuses.synchronized {
            messageStatuses.get(bufferMessage.ackId) match {
              case Some(status) => {
                messageStatuses -= bufferMessage.ackId
                status.markDone(Some(message))
              }
              case None => {
                /**
                 * We can fall down on this code because of following 2 cases
                 *
                 * (1) Invalid ack sent due to buggy code.
                 *
                 * (2) Late-arriving ack for a SendMessageStatus
                 *     To avoid unwilling late-arriving ack
                 *     caused by long pause like GC, you can set
                 *     larger value than default to spark.core.connection.ack.wait.timeout
                 */
                logWarning(s"Could not find reference for received ack Message ${message.id}")
              }
            }
          }
        } else {
          var ackMessage : Option[Message] = None
          try {
            ackMessage = if (onReceiveCallback != null) {
              logDebug("Calling back")
              onReceiveCallback(bufferMessage, connectionManagerId)
            } else {
              logDebug("Not calling back as callback is null")
              None
            }

            if (ackMessage.isDefined) {
              if (!ackMessage.get.isInstanceOf[BufferMessage]) {
                logDebug("Response to " + bufferMessage + " is not a buffer message, it is of type "
                  + ackMessage.get.getClass)
              } else if (!ackMessage.get.asInstanceOf[BufferMessage].hasAckId) {
                logDebug("Response to " + bufferMessage + " does not have ack id set")
                ackMessage.get.asInstanceOf[BufferMessage].ackId = bufferMessage.id
              }
            }
          } catch {
            case e: Exception => {
              logError(s"Exception was thrown while processing message", e)
              // reply with an error-flagged ack so the sender's future fails
              val m = Message.createBufferMessage(bufferMessage.id)
              m.hasError = true
              ackMessage = Some(m)
            }
          } finally {
            // an ack is always sent; an empty one when the callback produced no response
            sendMessage(connectionManagerId, ackMessage.getOrElse {
              Message.createBufferMessage(bufferMessage.id)
            })
          }
        }
      }
      case _ => throw new Exception("Unknown type message received")
    }
  }

  /** Starts the client side SASL negotiation on `conn` if it has not been started yet. */
  private def checkSendAuthFirst(connManagerId: ConnectionManagerId, conn: SendingConnection)
    : Unit = {
    // see if we need to do sasl before writing
    // this should only be the first negotiation as the Client!!!
    if (!conn.isSaslComplete()) {
      conn.synchronized {
        if (conn.sparkSaslClient == null) {
          conn.sparkSaslClient = new SparkSaslClient(securityManager)
          try {
            val firstResponse: Array[Byte] = conn.sparkSaslClient.firstToken()
            val securityMsg = SecurityMessage.fromResponse(firstResponse,
              conn.connectionId.toString())
            val message = securityMsg.toBufferMessage
            if (message == null) throw new Exception("Error creating security message")
            connectionsAwaitingSasl += ((conn.connectionId, conn))
            sendSecurityMessage(connManagerId, message)
            logDebug("adding connectionsAwaitingSasl id: " + conn.connectionId)
          } catch {
            case e: Exception => {
              logError("Error getting first response from the SaslClient.", e)
              conn.close()
              throw new Exception("Error getting first response from the SaslClient")
            }
          }
        }
      }
    } else {
      logDebug("Sasl already established ")
    }
  }

  // allow us to add messages to the inbox for doing sasl negotiating
  private def sendSecurityMessage(connManagerId: ConnectionManagerId, message: Message): Unit = {
    def startNewConnection(): SendingConnection = {
      val inetSocketAddress = new InetSocketAddress(connManagerId.host, connManagerId.port)
      val newConnectionId = new ConnectionId(id, idCount.getAndIncrement.intValue)
      val newConnection = new SendingConnection(inetSocketAddress, selector, connManagerId,
        newConnectionId)
      logInfo("creating new sending connection for security! " + newConnectionId )
      registerRequests.enqueue(newConnection)

      newConnection
    }
    // I removed the lookupKey stuff as part of merge ... should I re-add it ?
    // We did not find it useful in our test-env ...
    // If we do re-add it, we should consistently use it everywhere I guess ?
    message.senderAddress = id.toSocketAddress()
    logTrace("Sending Security [" + message + "] to [" + connManagerId + "]")
    val connection = connectionsById.getOrElseUpdate(connManagerId, startNewConnection())

    // send security message until going connection has been authenticated
    connection.send(message)

    wakeupSelector()
  }

  /**
   * Sends `message` over the (possibly newly created) sending connection to
   * `connectionManagerId`. With authentication enabled the caller is blocked until SASL
   * negotiation completes or `authTimeout` seconds elapse; on failure the tracked
   * MessageStatus, if any, is marked as failed.
   */
  private def sendMessage(connectionManagerId: ConnectionManagerId, message: Message): Unit = {
    def startNewConnection(): SendingConnection = {
      val inetSocketAddress = new InetSocketAddress(connectionManagerId.host,
        connectionManagerId.port)
      val newConnectionId = new ConnectionId(id, idCount.getAndIncrement.intValue)
      val newConnection = new SendingConnection(inetSocketAddress, selector, connectionManagerId,
        newConnectionId)
      logTrace("creating new sending connection: " + newConnectionId)
      registerRequests.enqueue(newConnection)

      newConnection
    }
    val connection = connectionsById.getOrElseUpdate(connectionManagerId, startNewConnection())
    if (authEnabled) {
      checkSendAuthFirst(connectionManagerId, connection)
    }
    message.senderAddress = id.toSocketAddress()
    logDebug("Before Sending [" + message + "] to [" + connectionManagerId + "]" + " " +
      "connectionid: "  + connection.connectionId)

    if (authEnabled) {
      // if we aren't authenticated yet lets block the senders until authentication completes
      try {
        connection.getAuthenticated().synchronized {
          val clock = SystemClock
          val startTime = clock.getTime()
          // Long arithmetic: a large configured timeout must not overflow Int
          val authTimeoutMs = authTimeout * 1000L

          while (!connection.isSaslComplete()) {
            logDebug("getAuthenticated wait connectionid: " + connection.connectionId)
            // have timeout in case remote side never responds
            connection.getAuthenticated().wait(500)
            if (((clock.getTime() - startTime) >= authTimeoutMs)
              && (!connection.isSaslComplete())) {
              // took too long to authenticate the connection, something probably went wrong
              throw new Exception("Took too long for authentication to " + connectionManagerId +
                ", waited " + authTimeout + " seconds, failing.")
            }
          }
        }
      } catch {
        case e: Exception => {
          logError("Exception while waiting for authentication.", e)

          // need to tell sender it failed
          messageStatuses.synchronized {
            messageStatuses.get(message.id) match {
              case Some(msgStatus) => {
                messageStatuses -= message.id
                logInfo("Notifying " + msgStatus.connectionManagerId)
                msgStatus.markDone(None)
              }
              case None => {
                logError("no messageStatus for failed message id: " + message.id)
              }
            }
          }
        }
      }
    }
    logDebug("Sending [" + message + "] to [" + connectionManagerId + "]")
    connection.send(message)

    wakeupSelector()
  }

  private def wakeupSelector(): Unit = {
    selector.wakeup()
  }

  /**
   * Send a message and return a Future that completes once an acknowledgment is received or
   * an error occurs. The Future itself is returned as soon as the message has been handed to
   * the connection (which may block while authentication is in progress, see sendMessage).
   * @param connectionManagerId the message's destination
   * @param message the message being sent
   * @return a Future that either returns the acknowledgment message or captures an exception.
   */
  def sendMessageReliably(connectionManagerId: ConnectionManagerId, message: Message)
      : Future[Message] = {
    val promise = Promise[Message]()

    // Fails the promise if the status is still being tracked when the ack timeout fires.
    val timeoutTask = new TimerTask {
      override def run(): Unit = {
        messageStatuses.synchronized {
          messageStatuses.remove(message.id).foreach ( s => {
            promise.failure(
              new IOException("sendMessageReliably failed because ack " +
                s"was not received within $ackTimeout sec"))
          })
        }
      }
    }

    val status = new MessageStatus(message, connectionManagerId, s => {
      timeoutTask.cancel()
      s.ackMessage match {
        case None => // Indicates a failure where we either never sent or never got ACK'd
          promise.failure(new IOException("sendMessageReliably failed without being ACK'd"))
        case Some(ackMessage) =>
          if (ackMessage.hasError) {
            promise.failure(
              new IOException("sendMessageReliably failed with ACK that signalled a remote error"))
          } else {
            promise.success(ackMessage)
          }
      }
    })
    messageStatuses.synchronized {
      messageStatuses += ((message.id, status))
    }

    // Timer.schedule takes a delay in milliseconds as a Long; avoid Int overflow
    ackTimeoutMonitor.schedule(timeoutTask, ackTimeout * 1000L)
    sendMessage(connectionManagerId, message)
    promise.future
  }

  /** Registers the callback invoked for every received non-ack message. */
  def onReceiveMessage(callback: (Message, ConnectionManagerId) => Option[Message]): Unit = {
    onReceiveCallback = callback
  }

  /** Stops the selector thread, closes all connections and shuts the thread pools down. */
  def stop(): Unit = {
    ackTimeoutMonitor.cancel()
    selectorThread.interrupt()
    selectorThread.join()
    selector.close()
    val connections = connectionsByKey.values
    connections.foreach(_.close())
    if (connectionsByKey.size != 0) {
      logWarning("All connections not cleaned up")
    }
    handleMessageExecutor.shutdown()
    handleReadWriteExecutor.shutdown()
    handleConnectExecutor.shutdown()
    logInfo("ConnectionManager stopped")
  }
}


/** Manual throughput tests that send messages from a ConnectionManager to itself. */
private[spark] object ConnectionManager {
  import scala.concurrent.ExecutionContext.Implicits.global

  def main(args: Array[String]): Unit = {
    val conf = new SparkConf
    val manager = new ConnectionManager(9999, conf, new SecurityManager(conf))
    manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
      println("Received [" + msg + "] from [" + id + "]")
      None
    })

    /* testSequentialSending(manager) */
    /* System.gc() */

    /* testParallelSending(manager) */
    /* System.gc() */

    /* testParallelDecreasingSending(manager) */
    /* System.gc() */

    testContinuousSending(manager)
    System.gc()
  }

  /** Sends `count` 10 MB messages one after another, waiting for each ack. */
  def testSequentialSending(manager: ConnectionManager): Unit = {
    println("--------------------------")
    println("Sequential Sending")
    println("--------------------------")
    val size = 10 * 1024 * 1024
    val count = 10

    val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte))
    buffer.flip

    (0 until count).foreach { i =>
      val bufferMessage = Message.createBufferMessage(buffer.duplicate)
      Await.result(manager.sendMessageReliably(manager.id, bufferMessage), Duration.Inf)
    }
    println("--------------------------")
    println()
  }

  /** Sends `count` 10 MB messages concurrently and reports the throughput. */
  def testParallelSending(manager: ConnectionManager): Unit = {
    println("--------------------------")
    println("Parallel Sending")
    println("--------------------------")
    val size = 10 * 1024 * 1024
    val count = 10

    val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte))
    buffer.flip

    val startTime = System.currentTimeMillis
    (0 until count).map(i => {
      val bufferMessage = Message.createBufferMessage(buffer.duplicate)
      manager.sendMessageReliably(manager.id, bufferMessage)
    }).foreach(f => {
      f.onFailure {
        case e => println("Failed due to " + e)
      }
      Await.ready(f, 1 second)
    })
    val finishTime = System.currentTimeMillis

    val mb = size * count / 1024.0 / 1024.0
    val ms = finishTime - startTime
    val tput = mb * 1000.0 / ms
    println("--------------------------")
    println("Started at " + startTime + ", finished at " + finishTime)
    println("Sent " + count + " messages of size " + size + " in " + ms + " ms " +
      "(" + tput + " MB/s)")
    println("--------------------------")
    println()
  }

  /** Sends messages of 100 MB down to 10 MB concurrently and reports the throughput. */
  def testParallelDecreasingSending(manager: ConnectionManager): Unit = {
    println("--------------------------")
    println("Parallel Decreasing Sending")
    println("--------------------------")
    val size = 10 * 1024 * 1024
    val count = 10
    val buffers = Array.tabulate(count) { i =>
      val bufferLen = size * (i + 1)
      val bufferContent = Array.tabulate[Byte](bufferLen)(x => x.toByte)
      ByteBuffer.allocate(bufferLen).put(bufferContent)
    }
    buffers.foreach(_.flip)
    val mb = buffers.map(_.remaining).reduceLeft(_ + _) / 1024.0 / 1024.0

    val startTime = System.currentTimeMillis
    (0 until count).map(i => {
      val bufferMessage = Message.createBufferMessage(buffers(count - 1 - i).duplicate)
      manager.sendMessageReliably(manager.id, bufferMessage)
    }).foreach(f => {
      f.onFailure {
        case e => println("Failed due to " + e)
      }
      Await.ready(f, 1 second)
    })
    val finishTime = System.currentTimeMillis

    val ms = finishTime - startTime
    val tput = mb * 1000.0 / ms
    println("--------------------------")
    /* println("Started at " + startTime + ", finished at " + finishTime) */
    println("Sent " + mb + " MB in " + ms + " ms (" + tput + " MB/s)")
    println("--------------------------")
    println()
  }

  /** Repeats the parallel send forever, printing the throughput of every round. */
  def testContinuousSending(manager: ConnectionManager): Unit = {
    println("--------------------------")
    println("Continuous Sending")
    println("--------------------------")
    val size = 10 * 1024 * 1024
    val count = 10

    val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte))
    buffer.flip

    while (true) {
      // Sampled per round: `mb` below is the volume of one round, so the elapsed time must
      // cover the same round, otherwise the reported throughput decays towards zero.
      val startTime = System.currentTimeMillis
      (0 until count).map(i => {
        val bufferMessage = Message.createBufferMessage(buffer.duplicate)
        manager.sendMessageReliably(manager.id, bufferMessage)
      }).foreach(f => {
        f.onFailure {
          case e => println("Failed due to " + e)
        }
        Await.ready(f, 1 second)
      })
      val finishTime = System.currentTimeMillis
      Thread.sleep(1000)
      val mb = size * count / 1024.0 / 1024.0
      val ms = finishTime - startTime
      val tput = mb * 1000.0 / ms
      println("Sent " + mb + " MB in " + ms + " ms (" + tput + " MB/s)")
      println("--------------------------")
      println()
    }
  }
}
http://git-wip-us.apache.org/repos/asf/spark/blob/08ce1888/core/src/main/scala/org/apache/spark/network/nio/ConnectionManagerId.scala ---------------------------------------------------------------------- diff --git a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManagerId.scala b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManagerId.scala new file mode 100644 index 0000000..cbb37ec --- /dev/null +++ b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManagerId.scala @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.network.nio + +import java.net.InetSocketAddress + +import org.apache.spark.util.Utils + +private[nio] case class ConnectionManagerId(host: String, port: Int) { + // DEBUG code + Utils.checkHost(host) + assert (port > 0) + + def toSocketAddress() = new InetSocketAddress(host, port) +} + + +private[nio] object ConnectionManagerId { + def fromSocketAddress(socketAddress: InetSocketAddress): ConnectionManagerId = { + new ConnectionManagerId(socketAddress.getHostName, socketAddress.getPort) + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/08ce1888/core/src/main/scala/org/apache/spark/network/nio/Message.scala ---------------------------------------------------------------------- diff --git a/core/src/main/scala/org/apache/spark/network/nio/Message.scala b/core/src/main/scala/org/apache/spark/network/nio/Message.scala new file mode 100644 index 0000000..0b874c2 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/network/nio/Message.scala @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.network.nio + +import java.net.InetSocketAddress +import java.nio.ByteBuffer + +import scala.collection.mutable.ArrayBuffer + + +private[nio] abstract class Message(val typ: Long, val id: Int) { + var senderAddress: InetSocketAddress = null + var started = false + var startTime = -1L + var finishTime = -1L + var isSecurityNeg = false + var hasError = false + + def size: Int + + def getChunkForSending(maxChunkSize: Int): Option[MessageChunk] + + def getChunkForReceiving(chunkSize: Int): Option[MessageChunk] + + def timeTaken(): String = (finishTime - startTime).toString + " ms" + + override def toString = this.getClass.getSimpleName + "(id = " + id + ", size = " + size + ")" +} + + +private[nio] object Message { + val BUFFER_MESSAGE = 1111111111L + + var lastId = 1 + + def getNewId() = synchronized { + lastId += 1 + if (lastId == 0) { + lastId += 1 + } + lastId + } + + def createBufferMessage(dataBuffers: Seq[ByteBuffer], ackId: Int): BufferMessage = { + if (dataBuffers == null) { + return new BufferMessage(getNewId(), new ArrayBuffer[ByteBuffer], ackId) + } + if (dataBuffers.exists(_ == null)) { + throw new Exception("Attempting to create buffer message with null buffer") + } + new BufferMessage(getNewId(), new ArrayBuffer[ByteBuffer] ++= dataBuffers, ackId) + } + + def createBufferMessage(dataBuffers: Seq[ByteBuffer]): BufferMessage = + createBufferMessage(dataBuffers, 0) + + def createBufferMessage(dataBuffer: ByteBuffer, ackId: Int): BufferMessage = { + if (dataBuffer == null) { + createBufferMessage(Array(ByteBuffer.allocate(0)), ackId) + } else { + createBufferMessage(Array(dataBuffer), ackId) + } + } + + def createBufferMessage(dataBuffer: ByteBuffer): BufferMessage = + createBufferMessage(dataBuffer, 0) + + def createBufferMessage(ackId: Int): BufferMessage = { + createBufferMessage(new Array[ByteBuffer](0), ackId) + } + + def create(header: MessageChunkHeader): Message = { + val newMessage: Message = header.typ match { + 
case BUFFER_MESSAGE => new BufferMessage(header.id, + ArrayBuffer(ByteBuffer.allocate(header.totalSize)), header.other) + } + newMessage.hasError = header.hasError + newMessage.senderAddress = header.address + newMessage + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/08ce1888/core/src/main/scala/org/apache/spark/network/nio/MessageChunk.scala ---------------------------------------------------------------------- diff --git a/core/src/main/scala/org/apache/spark/network/nio/MessageChunk.scala b/core/src/main/scala/org/apache/spark/network/nio/MessageChunk.scala new file mode 100644 index 0000000..278c5ac --- /dev/null +++ b/core/src/main/scala/org/apache/spark/network/nio/MessageChunk.scala @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.network.nio + +import java.nio.ByteBuffer + +import scala.collection.mutable.ArrayBuffer + +private[nio] +class MessageChunk(val header: MessageChunkHeader, val buffer: ByteBuffer) { + + val size = if (buffer == null) 0 else buffer.remaining + + lazy val buffers = { + val ab = new ArrayBuffer[ByteBuffer]() + ab += header.buffer + if (buffer != null) { + ab += buffer + } + ab + } + + override def toString = { + "" + this.getClass.getSimpleName + " (id = " + header.id + ", size = " + size + ")" + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/08ce1888/core/src/main/scala/org/apache/spark/network/nio/MessageChunkHeader.scala ---------------------------------------------------------------------- diff --git a/core/src/main/scala/org/apache/spark/network/nio/MessageChunkHeader.scala b/core/src/main/scala/org/apache/spark/network/nio/MessageChunkHeader.scala new file mode 100644 index 0000000..6e20f29 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/network/nio/MessageChunkHeader.scala @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.network.nio + +import java.net.{InetAddress, InetSocketAddress} +import java.nio.ByteBuffer + +private[nio] class MessageChunkHeader( + val typ: Long, + val id: Int, + val totalSize: Int, + val chunkSize: Int, + val other: Int, + val hasError: Boolean, + val securityNeg: Int, + val address: InetSocketAddress) { + lazy val buffer = { + // No need to change this, at 'use' time, we do a reverse lookup of the hostname. + // Refer to network.Connection + val ip = address.getAddress.getAddress() + val port = address.getPort() + ByteBuffer. + allocate(MessageChunkHeader.HEADER_SIZE). + putLong(typ). + putInt(id). + putInt(totalSize). + putInt(chunkSize). + putInt(other). + put(if (hasError) 1.asInstanceOf[Byte] else 0.asInstanceOf[Byte]). + putInt(securityNeg). + putInt(ip.size). + put(ip). + putInt(port). + position(MessageChunkHeader.HEADER_SIZE). + flip.asInstanceOf[ByteBuffer] + } + + override def toString = "" + this.getClass.getSimpleName + ":" + id + " of type " + typ + + " and sizes " + totalSize + " / " + chunkSize + " bytes, securityNeg: " + securityNeg + +} + + +private[nio] object MessageChunkHeader { + val HEADER_SIZE = 45 + + def create(buffer: ByteBuffer): MessageChunkHeader = { + if (buffer.remaining != HEADER_SIZE) { + throw new IllegalArgumentException("Cannot convert buffer data to Message") + } + val typ = buffer.getLong() + val id = buffer.getInt() + val totalSize = buffer.getInt() + val chunkSize = buffer.getInt() + val other = buffer.getInt() + val hasError = buffer.get() != 0 + val securityNeg = buffer.getInt() + val ipSize = buffer.getInt() + val ipBytes = new Array[Byte](ipSize) + buffer.get(ipBytes) + val ip = InetAddress.getByAddress(ipBytes) + val port = buffer.getInt() + new MessageChunkHeader(typ, id, totalSize, chunkSize, other, hasError, securityNeg, + new InetSocketAddress(ip, port)) + } +} 
http://git-wip-us.apache.org/repos/asf/spark/blob/08ce1888/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala ---------------------------------------------------------------------- diff --git a/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala new file mode 100644 index 0000000..59958ee --- /dev/null +++ b/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala @@ -0,0 +1,205 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.nio + +import java.nio.ByteBuffer + +import scala.concurrent.Future + +import org.apache.spark.{SparkException, Logging, SecurityManager, SparkConf} +import org.apache.spark.network._ +import org.apache.spark.storage.{BlockId, StorageLevel} +import org.apache.spark.util.Utils + + +/** + * A [[BlockTransferService]] implementation based on [[ConnectionManager]], a custom + * implementation using Java NIO. 
+ */ +final class NioBlockTransferService(conf: SparkConf, securityManager: SecurityManager) + extends BlockTransferService with Logging { + + private var cm: ConnectionManager = _ + + private var blockDataManager: BlockDataManager = _ + + /** + * Port number the service is listening on, available only after [[init]] is invoked. + */ + override def port: Int = { + checkInit() + cm.id.port + } + + /** + * Host name the service is listening on, available only after [[init]] is invoked. + */ + override def hostName: String = { + checkInit() + cm.id.host + } + + /** + * Initialize the transfer service by giving it the BlockDataManager that can be used to fetch + * local blocks or put local blocks. + */ + override def init(blockDataManager: BlockDataManager): Unit = { + this.blockDataManager = blockDataManager + cm = new ConnectionManager( + conf.getInt("spark.blockManager.port", 0), + conf, + securityManager, + "Connection manager for block manager") + cm.onReceiveMessage(onBlockMessageReceive) + } + + /** + * Tear down the transfer service. + */ + override def stop(): Unit = { + if (cm != null) { + cm.stop() + } + } + + override def fetchBlocks( + hostName: String, + port: Int, + blockIds: Seq[String], + listener: BlockFetchingListener): Unit = { + checkInit() + + val cmId = new ConnectionManagerId(hostName, port) + val blockMessageArray = new BlockMessageArray(blockIds.map { blockId => + BlockMessage.fromGetBlock(GetBlock(BlockId(blockId))) + }) + + val future = cm.sendMessageReliably(cmId, blockMessageArray.toBufferMessage) + + // Register the listener on success/failure future callback. 
+ future.onSuccess { case message => + val bufferMessage = message.asInstanceOf[BufferMessage] + val blockMessageArray = BlockMessageArray.fromBufferMessage(bufferMessage) + + for (blockMessage <- blockMessageArray) { + if (blockMessage.getType != BlockMessage.TYPE_GOT_BLOCK) { + listener.onBlockFetchFailure( + new SparkException(s"Unexpected message ${blockMessage.getType} received from $cmId")) + } else { + val blockId = blockMessage.getId + val networkSize = blockMessage.getData.limit() + listener.onBlockFetchSuccess( + blockId.toString, new NioByteBufferManagedBuffer(blockMessage.getData)) + } + } + }(cm.futureExecContext) + + future.onFailure { case exception => + listener.onBlockFetchFailure(exception) + }(cm.futureExecContext) + } + + /** + * Upload a single block to a remote node, available only after [[init]] is invoked. + * + * This call blocks until the upload completes, or throws an exception upon failures. + */ + override def uploadBlock( + hostname: String, + port: Int, + blockId: String, + blockData: ManagedBuffer, + level: StorageLevel) + : Future[Unit] = { + checkInit() + val msg = PutBlock(BlockId(blockId), blockData.nioByteBuffer(), level) + val blockMessageArray = new BlockMessageArray(BlockMessage.fromPutBlock(msg)) + val remoteCmId = new ConnectionManagerId(hostName, port) + val reply = cm.sendMessageReliably(remoteCmId, blockMessageArray.toBufferMessage) + reply.map(x => ())(cm.futureExecContext) + } + + private def checkInit(): Unit = if (cm == null) { + throw new IllegalStateException(getClass.getName + " has not been initialized") + } + + private def onBlockMessageReceive(msg: Message, id: ConnectionManagerId): Option[Message] = { + logDebug("Handling message " + msg) + msg match { + case bufferMessage: BufferMessage => + try { + logDebug("Handling as a buffer message " + bufferMessage) + val blockMessages = BlockMessageArray.fromBufferMessage(bufferMessage) + logDebug("Parsed as a block message array") + val responseMessages = 
blockMessages.map(processBlockMessage).filter(_ != None).map(_.get) + Some(new BlockMessageArray(responseMessages).toBufferMessage) + } catch { + case e: Exception => { + logError("Exception handling buffer message", e) + val errorMessage = Message.createBufferMessage(msg.id) + errorMessage.hasError = true + Some(errorMessage) + } + } + + case otherMessage: Any => + logError("Unknown type message received: " + otherMessage) + val errorMessage = Message.createBufferMessage(msg.id) + errorMessage.hasError = true + Some(errorMessage) + } + } + + private def processBlockMessage(blockMessage: BlockMessage): Option[BlockMessage] = { + blockMessage.getType match { + case BlockMessage.TYPE_PUT_BLOCK => + val msg = PutBlock(blockMessage.getId, blockMessage.getData, blockMessage.getLevel) + logDebug("Received [" + msg + "]") + putBlock(msg.id.toString, msg.data, msg.level) + None + + case BlockMessage.TYPE_GET_BLOCK => + val msg = new GetBlock(blockMessage.getId) + logDebug("Received [" + msg + "]") + val buffer = getBlock(msg.id.toString) + if (buffer == null) { + return None + } + Some(BlockMessage.fromGotBlock(GotBlock(msg.id, buffer))) + + case _ => None + } + } + + private def putBlock(blockId: String, bytes: ByteBuffer, level: StorageLevel) { + val startTimeMs = System.currentTimeMillis() + logDebug("PutBlock " + blockId + " started from " + startTimeMs + " with data: " + bytes) + blockDataManager.putBlockData(blockId, new NioByteBufferManagedBuffer(bytes), level) + logDebug("PutBlock " + blockId + " used " + Utils.getUsedTimeMs(startTimeMs) + + " with data size: " + bytes.limit) + } + + private def getBlock(blockId: String): ByteBuffer = { + val startTimeMs = System.currentTimeMillis() + logDebug("GetBlock " + blockId + " started from " + startTimeMs) + val buffer = blockDataManager.getBlockData(blockId).orNull + logDebug("GetBlock " + blockId + " used " + Utils.getUsedTimeMs(startTimeMs) + + " and got buffer " + buffer) + buffer.nioByteBuffer() + } +} 
http://git-wip-us.apache.org/repos/asf/spark/blob/08ce1888/core/src/main/scala/org/apache/spark/network/nio/SecurityMessage.scala ---------------------------------------------------------------------- diff --git a/core/src/main/scala/org/apache/spark/network/nio/SecurityMessage.scala b/core/src/main/scala/org/apache/spark/network/nio/SecurityMessage.scala new file mode 100644 index 0000000..747a208 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/network/nio/SecurityMessage.scala @@ -0,0 +1,160 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.nio + +import java.nio.ByteBuffer + +import scala.collection.mutable.{ArrayBuffer, StringBuilder} + +import org.apache.spark._ + +/** + * SecurityMessage is class that contains the connectionId and sasl token + * used in SASL negotiation. SecurityMessage has routines for converting + * it to and from a BufferMessage so that it can be sent by the ConnectionManager + * and easily consumed by users when received. + * The api was modeled after BlockMessage. + * + * The connectionId is the connectionId of the client side. 
Since + * message passing is asynchronous and its possible for the server side (receiving) + * to get multiple different types of messages on the same connection the connectionId + * is used to know which connnection the security message is intended for. + * + * For instance, lets say we are node_0. We need to send data to node_1. The node_0 side + * is acting as a client and connecting to node_1. SASL negotiation has to occur + * between node_0 and node_1 before node_1 trusts node_0 so node_0 sends a security message. + * node_1 receives the message from node_0 but before it can process it and send a response, + * some thread on node_1 decides it needs to send data to node_0 so it connects to node_0 + * and sends a security message of its own to authenticate as a client. Now node_0 gets + * the message and it needs to decide if this message is in response to it being a client + * (from the first send) or if its just node_1 trying to connect to it to send data. This + * is where the connectionId field is used. node_0 can lookup the connectionId to see if + * it is in response to it being a client or if its in response to someone sending other data. + * + * The format of a SecurityMessage as its sent is: + * - Length of the ConnectionId + * - ConnectionId + * - Length of the token + * - Token + */ +private[nio] class SecurityMessage extends Logging { + + private var connectionId: String = null + private var token: Array[Byte] = null + + def set(byteArr: Array[Byte], newconnectionId: String) { + if (byteArr == null) { + token = new Array[Byte](0) + } else { + token = byteArr + } + connectionId = newconnectionId + } + + /** + * Read the given buffer and set the members of this class. 
+ */ + def set(buffer: ByteBuffer) { + val idLength = buffer.getInt() + val idBuilder = new StringBuilder(idLength) + for (i <- 1 to idLength) { + idBuilder += buffer.getChar() + } + connectionId = idBuilder.toString() + + val tokenLength = buffer.getInt() + token = new Array[Byte](tokenLength) + if (tokenLength > 0) { + buffer.get(token, 0, tokenLength) + } + } + + def set(bufferMsg: BufferMessage) { + val buffer = bufferMsg.buffers.apply(0) + buffer.clear() + set(buffer) + } + + def getConnectionId: String = { + return connectionId + } + + def getToken: Array[Byte] = { + return token + } + + /** + * Create a BufferMessage that can be sent by the ConnectionManager containing + * the security information from this class. + * @return BufferMessage + */ + def toBufferMessage: BufferMessage = { + val buffers = new ArrayBuffer[ByteBuffer]() + + // 4 bytes for the length of the connectionId + // connectionId is of type char so multiple the length by 2 to get number of bytes + // 4 bytes for the length of token + // token is a byte buffer so just take the length + var buffer = ByteBuffer.allocate(4 + connectionId.length() * 2 + 4 + token.length) + buffer.putInt(connectionId.length()) + connectionId.foreach((x: Char) => buffer.putChar(x)) + buffer.putInt(token.length) + + if (token.length > 0) { + buffer.put(token) + } + buffer.flip() + buffers += buffer + + var message = Message.createBufferMessage(buffers) + logDebug("message total size is : " + message.size) + message.isSecurityNeg = true + return message + } + + override def toString: String = { + "SecurityMessage [connId= " + connectionId + ", Token = " + token + "]" + } +} + +private[nio] object SecurityMessage { + + /** + * Convert the given BufferMessage to a SecurityMessage by parsing the contents + * of the BufferMessage and populating the SecurityMessage fields. 
+ * @param bufferMessage is a BufferMessage that was received + * @return new SecurityMessage + */ + def fromBufferMessage(bufferMessage: BufferMessage): SecurityMessage = { + val newSecurityMessage = new SecurityMessage() + newSecurityMessage.set(bufferMessage) + newSecurityMessage + } + + /** + * Create a SecurityMessage to send from a given saslResponse. + * @param response is the response to a challenge from the SaslClient or Saslserver + * @param connectionId the client connectionId we are negotiation authentication for + * @return a new SecurityMessage + */ + def fromResponse(response : Array[Byte], connectionId : String) : SecurityMessage = { + val newSecurityMessage = new SecurityMessage() + newSecurityMessage.set(response, connectionId) + newSecurityMessage + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/08ce1888/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala ---------------------------------------------------------------------- diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index 87ef9bb..d6386f8 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -27,9 +27,9 @@ import com.twitter.chill.{AllScalaRegistrar, EmptyScalaKryoInstantiator} import org.apache.spark._ import org.apache.spark.broadcast.HttpBroadcast +import org.apache.spark.network.nio.{PutBlock, GotBlock, GetBlock} import org.apache.spark.scheduler.MapStatus import org.apache.spark.storage._ -import org.apache.spark.storage.{GetBlock, GotBlock, PutBlock} import org.apache.spark.util.BoundedPriorityQueue import org.apache.spark.util.collection.CompactBuffer http://git-wip-us.apache.org/repos/asf/spark/blob/08ce1888/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala ---------------------------------------------------------------------- 
diff --git a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala index 96faccc..439981d 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala @@ -26,6 +26,7 @@ import scala.collection.JavaConversions._ import org.apache.spark.{SparkEnv, SparkConf, Logging} import org.apache.spark.executor.ShuffleWriteMetrics +import org.apache.spark.network.{FileSegmentManagedBuffer, ManagedBuffer} import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.FileShuffleBlockManager.ShuffleFileGroup import org.apache.spark.storage._ @@ -166,34 +167,30 @@ class FileShuffleBlockManager(conf: SparkConf) } } - /** - * Returns the physical file segment in which the given BlockId is located. - */ - private def getBlockLocation(id: ShuffleBlockId): FileSegment = { + override def getBytes(blockId: ShuffleBlockId): Option[ByteBuffer] = { + val segment = getBlockData(blockId) + Some(segment.nioByteBuffer()) + } + + override def getBlockData(blockId: ShuffleBlockId): ManagedBuffer = { if (consolidateShuffleFiles) { // Search all file groups associated with this shuffle. 
- val shuffleState = shuffleStates(id.shuffleId) + val shuffleState = shuffleStates(blockId.shuffleId) val iter = shuffleState.allFileGroups.iterator while (iter.hasNext) { - val segment = iter.next.getFileSegmentFor(id.mapId, id.reduceId) - if (segment.isDefined) { return segment.get } + val segmentOpt = iter.next.getFileSegmentFor(blockId.mapId, blockId.reduceId) + if (segmentOpt.isDefined) { + val segment = segmentOpt.get + return new FileSegmentManagedBuffer(segment.file, segment.offset, segment.length) + } } - throw new IllegalStateException("Failed to find shuffle block: " + id) + throw new IllegalStateException("Failed to find shuffle block: " + blockId) } else { - val file = blockManager.diskBlockManager.getFile(id) - new FileSegment(file, 0, file.length()) + val file = blockManager.diskBlockManager.getFile(blockId) + new FileSegmentManagedBuffer(file, 0, file.length) } } - override def getBytes(blockId: ShuffleBlockId): Option[ByteBuffer] = { - val segment = getBlockLocation(blockId) - blockManager.diskStore.getBytes(segment) - } - - override def getBlockData(blockId: ShuffleBlockId): Either[FileSegment, ByteBuffer] = { - Left(getBlockLocation(blockId.asInstanceOf[ShuffleBlockId])) - } - /** Remove all the blocks / files and metadata related to a particular shuffle. 
*/ def removeShuffle(shuffleId: ShuffleId): Boolean = { // Do not change the ordering of this, if shuffleStates should be removed only http://git-wip-us.apache.org/repos/asf/spark/blob/08ce1888/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala ---------------------------------------------------------------------- diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala index 8bb9efc..4ab3433 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala @@ -21,6 +21,7 @@ import java.io._ import java.nio.ByteBuffer import org.apache.spark.SparkEnv +import org.apache.spark.network.{ManagedBuffer, FileSegmentManagedBuffer} import org.apache.spark.storage._ /** @@ -89,10 +90,11 @@ class IndexShuffleBlockManager extends ShuffleBlockManager { } } - /** - * Get the location of a block in a map output file. Uses the index file we create for it. 
- * */ - private def getBlockLocation(blockId: ShuffleBlockId): FileSegment = { + override def getBytes(blockId: ShuffleBlockId): Option[ByteBuffer] = { + Some(getBlockData(blockId).nioByteBuffer()) + } + + override def getBlockData(blockId: ShuffleBlockId): ManagedBuffer = { // The block is actually going to be a range of a single map output file for this map, so // find out the consolidated file, then the offset within that from our index val indexFile = getIndexFile(blockId.shuffleId, blockId.mapId) @@ -102,20 +104,14 @@ class IndexShuffleBlockManager extends ShuffleBlockManager { in.skip(blockId.reduceId * 8) val offset = in.readLong() val nextOffset = in.readLong() - new FileSegment(getDataFile(blockId.shuffleId, blockId.mapId), offset, nextOffset - offset) + new FileSegmentManagedBuffer( + getDataFile(blockId.shuffleId, blockId.mapId), + offset, + nextOffset - offset) } finally { in.close() } } - override def getBytes(blockId: ShuffleBlockId): Option[ByteBuffer] = { - val segment = getBlockLocation(blockId) - blockManager.diskStore.getBytes(segment) - } - - override def getBlockData(blockId: ShuffleBlockId): Either[FileSegment, ByteBuffer] = { - Left(getBlockLocation(blockId.asInstanceOf[ShuffleBlockId])) - } - override def stop() = {} } http://git-wip-us.apache.org/repos/asf/spark/blob/08ce1888/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockManager.scala ---------------------------------------------------------------------- diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockManager.scala index 4240580..63863cc 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockManager.scala @@ -19,7 +19,8 @@ package org.apache.spark.shuffle import java.nio.ByteBuffer -import org.apache.spark.storage.{FileSegment, ShuffleBlockId} +import org.apache.spark.network.ManagedBuffer 
+import org.apache.spark.storage.ShuffleBlockId private[spark] trait ShuffleBlockManager { @@ -31,8 +32,7 @@ trait ShuffleBlockManager { */ def getBytes(blockId: ShuffleBlockId): Option[ByteBuffer] - def getBlockData(blockId: ShuffleBlockId): Either[FileSegment, ByteBuffer] + def getBlockData(blockId: ShuffleBlockId): ManagedBuffer def stop(): Unit } - http://git-wip-us.apache.org/repos/asf/spark/blob/08ce1888/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala ---------------------------------------------------------------------- diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala index 12b4756..6cf9305 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala @@ -21,10 +21,9 @@ import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.HashMap import org.apache.spark._ -import org.apache.spark.executor.ShuffleReadMetrics import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.FetchFailedException -import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockId} +import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockFetcherIterator, ShuffleBlockId} import org.apache.spark.util.CompletionIterator private[hash] object BlockStoreShuffleFetcher extends Logging { @@ -32,8 +31,7 @@ private[hash] object BlockStoreShuffleFetcher extends Logging { shuffleId: Int, reduceId: Int, context: TaskContext, - serializer: Serializer, - shuffleMetrics: ShuffleReadMetrics) + serializer: Serializer) : Iterator[T] = { logDebug("Fetching outputs for shuffle %d, reduce %d".format(shuffleId, reduceId)) @@ -74,7 +72,13 @@ private[hash] object BlockStoreShuffleFetcher extends Logging { } } - val blockFetcherItr = blockManager.getMultiple(blocksByAddress, 
serializer, shuffleMetrics) + val blockFetcherItr = new ShuffleBlockFetcherIterator( + context, + SparkEnv.get.blockTransferService, + blockManager, + blocksByAddress, + serializer, + SparkEnv.get.conf.getLong("spark.reducer.maxMbInFlight", 48) * 1024 * 1024) val itr = blockFetcherItr.flatMap(unpackBlock) val completionIter = CompletionIterator[T, Iterator[T]](itr, { http://git-wip-us.apache.org/repos/asf/spark/blob/08ce1888/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala ---------------------------------------------------------------------- diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala index 7bed97a..88a5f1e 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala @@ -36,10 +36,8 @@ private[spark] class HashShuffleReader[K, C]( /** Read the combined key-values for this reduce task */ override def read(): Iterator[Product2[K, C]] = { - val readMetrics = context.taskMetrics.createShuffleReadMetricsForDependency() val ser = Serializer.getSerializer(dep.serializer) - val iter = BlockStoreShuffleFetcher.fetch(handle.shuffleId, startPartition, context, ser, - readMetrics) + val iter = BlockStoreShuffleFetcher.fetch(handle.shuffleId, startPartition, context, ser) val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) { if (dep.mapSideCombine) { http://git-wip-us.apache.org/repos/asf/spark/blob/08ce1888/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala ---------------------------------------------------------------------- diff --git a/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala deleted file mode 100644 index e35b7fe..0000000 --- 
a/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala +++ /dev/null @@ -1,254 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.storage - -import java.util.concurrent.LinkedBlockingQueue -import org.apache.spark.network.netty.client.{BlockClientListener, LazyInitIterator, ReferenceCountedBuffer} - -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.HashSet -import scala.collection.mutable.Queue -import scala.util.{Failure, Success} - -import org.apache.spark.{Logging, SparkException} -import org.apache.spark.executor.ShuffleReadMetrics -import org.apache.spark.network.BufferMessage -import org.apache.spark.network.ConnectionManagerId -import org.apache.spark.serializer.Serializer -import org.apache.spark.util.Utils - -/** - * A block fetcher iterator interface for fetching shuffle blocks. - */ -private[storage] -trait BlockFetcherIterator extends Iterator[(BlockId, Option[Iterator[Any]])] with Logging { - def initialize() -} - - -private[storage] -object BlockFetcherIterator { - - /** - * A request to fetch blocks from a remote BlockManager. - * @param address remote BlockManager to fetch from. 
- * @param blocks Sequence of tuple, where the first element is the block id, - * and the second element is the estimated size, used to calculate bytesInFlight. - */ - class FetchRequest(val address: BlockManagerId, val blocks: Seq[(BlockId, Long)]) { - val size = blocks.map(_._2).sum - } - - /** - * Result of a fetch from a remote block. A failure is represented as size == -1. - * @param blockId block id - * @param size estimated size of the block, used to calculate bytesInFlight. - * Note that this is NOT the exact bytes. - * @param deserialize closure to return the result in the form of an Iterator. - */ - class FetchResult(val blockId: BlockId, val size: Long, val deserialize: () => Iterator[Any]) { - def failed: Boolean = size == -1 - } - - // TODO: Refactor this whole thing to make code more reusable. - class BasicBlockFetcherIterator( - private val blockManager: BlockManager, - val blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])], - serializer: Serializer, - readMetrics: ShuffleReadMetrics) - extends BlockFetcherIterator { - - import blockManager._ - - if (blocksByAddress == null) { - throw new IllegalArgumentException("BlocksByAddress is null") - } - - // Total number blocks fetched (local + remote). Also number of FetchResults expected - protected var _numBlocksToFetch = 0 - - protected var startTime = System.currentTimeMillis - - // BlockIds for local blocks that need to be fetched. Excludes zero-sized blocks - protected val localBlocksToFetch = new ArrayBuffer[BlockId]() - - // BlockIds for remote blocks that need to be fetched. Excludes zero-sized blocks - protected val remoteBlocksToFetch = new HashSet[BlockId]() - - // A queue to hold our results. 
- protected val results = new LinkedBlockingQueue[FetchResult] - - // Queue of fetch requests to issue; we'll pull requests off this gradually to make sure that - // the number of bytes in flight is limited to maxBytesInFlight - protected val fetchRequests = new Queue[FetchRequest] - - // Current bytes in flight from our requests - protected var bytesInFlight = 0L - - protected def sendRequest(req: FetchRequest) { - logDebug("Sending request for %d blocks (%s) from %s".format( - req.blocks.size, Utils.bytesToString(req.size), req.address.hostPort)) - val cmId = new ConnectionManagerId(req.address.host, req.address.port) - val blockMessageArray = new BlockMessageArray(req.blocks.map { - case (blockId, size) => BlockMessage.fromGetBlock(GetBlock(blockId)) - }) - bytesInFlight += req.size - val sizeMap = req.blocks.toMap // so we can look up the size of each blockID - val future = connectionManager.sendMessageReliably(cmId, blockMessageArray.toBufferMessage) - future.onComplete { - case Success(message) => { - val bufferMessage = message.asInstanceOf[BufferMessage] - val blockMessageArray = BlockMessageArray.fromBufferMessage(bufferMessage) - for (blockMessage <- blockMessageArray) { - if (blockMessage.getType != BlockMessage.TYPE_GOT_BLOCK) { - throw new SparkException( - "Unexpected message " + blockMessage.getType + " received from " + cmId) - } - val blockId = blockMessage.getId - val networkSize = blockMessage.getData.limit() - results.put(new FetchResult(blockId, sizeMap(blockId), - () => dataDeserialize(blockId, blockMessage.getData, serializer))) - // TODO: NettyBlockFetcherIterator has some race conditions where multiple threads can - // be incrementing bytes read at the same time (SPARK-2625). 
- readMetrics.remoteBytesRead += networkSize - readMetrics.remoteBlocksFetched += 1 - logDebug("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime)) - } - } - case Failure(exception) => { - logError("Could not get block(s) from " + cmId, exception) - for ((blockId, size) <- req.blocks) { - results.put(new FetchResult(blockId, -1, null)) - } - } - } - } - - protected def splitLocalRemoteBlocks(): ArrayBuffer[FetchRequest] = { - // Make remote requests at most maxBytesInFlight / 5 in length; the reason to keep them - // smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5 - // nodes, rather than blocking on reading output from one node. - val targetRequestSize = math.max(maxBytesInFlight / 5, 1L) - logInfo("maxBytesInFlight: " + maxBytesInFlight + ", targetRequestSize: " + targetRequestSize) - - // Split local and remote blocks. Remote blocks are further split into FetchRequests of size - // at most maxBytesInFlight in order to limit the amount of data in flight. 
- val remoteRequests = new ArrayBuffer[FetchRequest] - var totalBlocks = 0 - for ((address, blockInfos) <- blocksByAddress) { - totalBlocks += blockInfos.size - if (address == blockManagerId) { - // Filter out zero-sized blocks - localBlocksToFetch ++= blockInfos.filter(_._2 != 0).map(_._1) - _numBlocksToFetch += localBlocksToFetch.size - } else { - val iterator = blockInfos.iterator - var curRequestSize = 0L - var curBlocks = new ArrayBuffer[(BlockId, Long)] - while (iterator.hasNext) { - val (blockId, size) = iterator.next() - // Skip empty blocks - if (size > 0) { - curBlocks += ((blockId, size)) - remoteBlocksToFetch += blockId - _numBlocksToFetch += 1 - curRequestSize += size - } else if (size < 0) { - throw new BlockException(blockId, "Negative block size " + size) - } - if (curRequestSize >= targetRequestSize) { - // Add this FetchRequest - remoteRequests += new FetchRequest(address, curBlocks) - curBlocks = new ArrayBuffer[(BlockId, Long)] - logDebug(s"Creating fetch request of $curRequestSize at $address") - curRequestSize = 0 - } - } - // Add in the final request - if (!curBlocks.isEmpty) { - remoteRequests += new FetchRequest(address, curBlocks) - } - } - } - logInfo("Getting " + _numBlocksToFetch + " non-empty blocks out of " + - totalBlocks + " blocks") - remoteRequests - } - - protected def getLocalBlocks() { - // Get the local blocks while remote blocks are being fetched. 
Note that it's okay to do - // these all at once because they will just memory-map some files, so they won't consume - // any memory that might exceed our maxBytesInFlight - for (id <- localBlocksToFetch) { - try { - readMetrics.localBlocksFetched += 1 - results.put(new FetchResult(id, 0, () => getLocalShuffleFromDisk(id, serializer).get)) - logDebug("Got local block " + id) - } catch { - case e: Exception => { - logError(s"Error occurred while fetching local blocks", e) - results.put(new FetchResult(id, -1, null)) - return - } - } - } - } - - override def initialize() { - // Split local and remote blocks. - val remoteRequests = splitLocalRemoteBlocks() - // Add the remote requests into our queue in a random order - fetchRequests ++= Utils.randomize(remoteRequests) - - // Send out initial requests for blocks, up to our maxBytesInFlight - while (!fetchRequests.isEmpty && - (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) { - sendRequest(fetchRequests.dequeue()) - } - - val numFetches = remoteRequests.size - fetchRequests.size - logInfo("Started " + numFetches + " remote fetches in" + Utils.getUsedTimeMs(startTime)) - - // Get Local Blocks - startTime = System.currentTimeMillis - getLocalBlocks() - logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime) + " ms") - } - - // Implementing the Iterator methods with an iterator that reads fetched blocks off the queue - // as they arrive. - @volatile protected var resultsGotten = 0 - - override def hasNext: Boolean = resultsGotten < _numBlocksToFetch - - override def next(): (BlockId, Option[Iterator[Any]]) = { - resultsGotten += 1 - val startFetchWait = System.currentTimeMillis() - val result = results.take() - val stopFetchWait = System.currentTimeMillis() - readMetrics.fetchWaitTime += (stopFetchWait - startFetchWait) - if (! 
result.failed) bytesInFlight -= result.size - while (!fetchRequests.isEmpty && - (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) { - sendRequest(fetchRequests.dequeue()) - } - (result.blockId, if (result.failed) None else Some(result.deserialize())) - } - } - // End of BasicBlockFetcherIterator -} --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
