Repository: spark
Updated Branches:
  refs/heads/master 5db78e6b8 -> 127e97bee


[SPARK-3632] ConnectionManager can run out of receive threads with 
authentication on

If authentication is turned on and many executors are in use, there is a
chance that all of the threads in the handleMessageExecutor end up waiting to
send a message because they are blocked waiting for authentication to
complete. This can cause a temporary deadlock until the connection times out.

To fix it, I removed the wait/notify and instead use a single outbox that only
sends security (SASL negotiation) messages until authentication has completed.

Author: Thomas Graves <[email protected]>

Closes #2484 from tgravescs/cm_threads_auth and squashes the following commits:

a0a961d [Thomas Graves] give it a type
b6bc80b [Thomas Graves] Rework comments
d6d4175 [Thomas Graves] update from comments
081b765 [Thomas Graves] cleanup
4d7f8f5 [Thomas Graves] Change to not use wait/notify while waiting for 
authentication


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/127e97be
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/127e97be
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/127e97be

Branch: refs/heads/master
Commit: 127e97bee1e6aae7b70263bc5944b7be6f4e6fea
Parents: 5db78e6
Author: Thomas Graves <[email protected]>
Authored: Thu Oct 2 13:52:54 2014 -0700
Committer: Reynold Xin <[email protected]>
Committed: Thu Oct 2 13:52:54 2014 -0700

----------------------------------------------------------------------
 .../org/apache/spark/SecurityManager.scala      |  5 +-
 .../apache/spark/network/nio/Connection.scala   | 65 ++++++++++++------
 .../spark/network/nio/ConnectionManager.scala   | 72 +++++---------------
 3 files changed, 63 insertions(+), 79 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/127e97be/core/src/main/scala/org/apache/spark/SecurityManager.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala 
b/core/src/main/scala/org/apache/spark/SecurityManager.scala
index 3832a78..0e0f1a7 100644
--- a/core/src/main/scala/org/apache/spark/SecurityManager.scala
+++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala
@@ -103,10 +103,9 @@ import org.apache.spark.deploy.SparkHadoopUtil
  *            and a Server, so for a particular connection is has to determine 
what to do.
  *            A ConnectionId was added to be able to track connections and is 
used to
  *            match up incoming messages with connections waiting for 
authentication.
- *            If its acting as a client and trying to send a message to 
another ConnectionManager,
- *            it blocks the thread calling sendMessage until the SASL 
negotiation has occurred.
  *            The ConnectionManager tracks all the sendingConnections using 
the ConnectionId
- *            and waits for the response from the server and does the 
handshake.
+ *            and waits for the response from the server and does the 
handshake before sending
+ *            the real message.
  *
  *  - HTTP for the Spark UI -> the UI was changed to use servlets so that 
javax servlet filters
  *            can be used. Yarn requires a specific AmIpFilter be installed 
for security to work

http://git-wip-us.apache.org/repos/asf/spark/blob/127e97be/core/src/main/scala/org/apache/spark/network/nio/Connection.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/network/nio/Connection.scala 
b/core/src/main/scala/org/apache/spark/network/nio/Connection.scala
index 18172d3..f368209 100644
--- a/core/src/main/scala/org/apache/spark/network/nio/Connection.scala
+++ b/core/src/main/scala/org/apache/spark/network/nio/Connection.scala
@@ -20,23 +20,27 @@ package org.apache.spark.network.nio
 import java.net._
 import java.nio._
 import java.nio.channels._
+import java.util.LinkedList
 
 import org.apache.spark._
 
-import scala.collection.mutable.{ArrayBuffer, HashMap, Queue}
+import scala.collection.mutable.{ArrayBuffer, HashMap}
 
 private[nio]
 abstract class Connection(val channel: SocketChannel, val selector: Selector,
-    val socketRemoteConnectionManagerId: ConnectionManagerId, val 
connectionId: ConnectionId)
+    val socketRemoteConnectionManagerId: ConnectionManagerId, val 
connectionId: ConnectionId,
+    val securityMgr: SecurityManager)
   extends Logging {
 
   var sparkSaslServer: SparkSaslServer = null
   var sparkSaslClient: SparkSaslClient = null
 
-  def this(channel_ : SocketChannel, selector_ : Selector, id_ : ConnectionId) 
= {
+  def this(channel_ : SocketChannel, selector_ : Selector, id_ : ConnectionId,
+      securityMgr_ : SecurityManager) = {
     this(channel_, selector_,
       ConnectionManagerId.fromSocketAddress(
-        
channel_.socket.getRemoteSocketAddress.asInstanceOf[InetSocketAddress]), id_)
+        
channel_.socket.getRemoteSocketAddress.asInstanceOf[InetSocketAddress]),
+        id_, securityMgr_)
   }
 
   channel.configureBlocking(false)
@@ -52,14 +56,6 @@ abstract class Connection(val channel: SocketChannel, val 
selector: Selector,
 
   val remoteAddress = getRemoteAddress()
 
-  /**
-   * Used to synchronize client requests: client's work-related requests must
-   * wait until SASL authentication completes.
-   */
-  private val authenticated = new Object()
-
-  def getAuthenticated(): Object = authenticated
-
   def isSaslComplete(): Boolean
 
   def resetForceReregister(): Boolean
@@ -192,22 +188,22 @@ abstract class Connection(val channel: SocketChannel, val 
selector: Selector,
 
 private[nio]
 class SendingConnection(val address: InetSocketAddress, selector_ : Selector,
-    remoteId_ : ConnectionManagerId, id_ : ConnectionId)
-  extends Connection(SocketChannel.open, selector_, remoteId_, id_) {
+    remoteId_ : ConnectionManagerId, id_ : ConnectionId,
+    securityMgr_ : SecurityManager)
+  extends Connection(SocketChannel.open, selector_, remoteId_, id_, 
securityMgr_) {
 
   def isSaslComplete(): Boolean = {
     if (sparkSaslClient != null) sparkSaslClient.isComplete() else false
   }
 
   private class Outbox {
-    val messages = new Queue[Message]()
+    val messages = new LinkedList[Message]()
     val defaultChunkSize = 65536
     var nextMessageToBeUsed = 0
 
     def addMessage(message: Message) {
       messages.synchronized {
-        /* messages += message */
-        messages.enqueue(message)
+        messages.add(message)
         logDebug("Added [" + message + "] to outbox for sending to " +
           "[" + getRemoteConnectionManagerId() + "]")
       }
@@ -218,10 +214,27 @@ class SendingConnection(val address: InetSocketAddress, 
selector_ : Selector,
         while (!messages.isEmpty) {
           /* nextMessageToBeUsed = nextMessageToBeUsed % messages.size */
           /* val message = messages(nextMessageToBeUsed) */
-          val message = messages.dequeue()
+
+          val message = if (securityMgr.isAuthenticationEnabled() && 
!isSaslComplete()) {
+            // only allow sending of security messages until sasl is complete
+            var pos = 0
+            var securityMsg: Message = null
+            while (pos < messages.size() && securityMsg == null) {
+              if (messages.get(pos).isSecurityNeg) {
+                securityMsg = messages.remove(pos)
+              }
+              pos = pos + 1
+            }
+            // didn't find any security messages and auth isn't completed so 
return
+            if (securityMsg == null) return None
+            securityMsg
+          } else {
+            messages.removeFirst()
+          }
+
           val chunk = message.getChunkForSending(defaultChunkSize)
           if (chunk.isDefined) {
-            messages.enqueue(message)
+            messages.add(message)
             nextMessageToBeUsed = nextMessageToBeUsed + 1
             if (!message.started) {
               logDebug(
@@ -273,6 +286,15 @@ class SendingConnection(val address: InetSocketAddress, 
selector_ : Selector,
     changeConnectionKeyInterest(DEFAULT_INTEREST)
   }
 
+  def registerAfterAuth(): Unit = {
+    outbox.synchronized {
+      needForceReregister = true
+    }
+    if (channel.isConnected) {
+      registerInterest()
+    }
+  }
+
   def send(message: Message) {
     outbox.synchronized {
       outbox.addMessage(message)
@@ -415,8 +437,9 @@ class SendingConnection(val address: InetSocketAddress, 
selector_ : Selector,
 private[spark] class ReceivingConnection(
     channel_ : SocketChannel,
     selector_ : Selector,
-    id_ : ConnectionId)
-    extends Connection(channel_, selector_, id_) {
+    id_ : ConnectionId,
+    securityMgr_ : SecurityManager)
+    extends Connection(channel_, selector_, id_, securityMgr_) {
 
   def isSaslComplete(): Boolean = {
     if (sparkSaslServer != null) sparkSaslServer.isComplete() else false

http://git-wip-us.apache.org/repos/asf/spark/blob/127e97be/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala
----------------------------------------------------------------------
diff --git 
a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala 
b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala
index 5aa7e94..01cd27a 100644
--- a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala
+++ b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala
@@ -32,7 +32,7 @@ import scala.concurrent.{Await, ExecutionContext, Future, 
Promise}
 import scala.language.postfixOps
 
 import org.apache.spark._
-import org.apache.spark.util.{SystemClock, Utils}
+import org.apache.spark.util.Utils
 
 
 private[nio] class ConnectionManager(
@@ -65,8 +65,6 @@ private[nio] class ConnectionManager(
   private val selector = SelectorProvider.provider.openSelector()
   private val ackTimeoutMonitor = new Timer("AckTimeoutMonitor", true)
 
-  // default to 30 second timeout waiting for authentication
-  private val authTimeout = 
conf.getInt("spark.core.connection.auth.wait.timeout", 30)
   private val ackTimeout = 
conf.getInt("spark.core.connection.ack.wait.timeout", 60)
 
   private val handleMessageExecutor = new ThreadPoolExecutor(
@@ -409,7 +407,8 @@ private[nio] class ConnectionManager(
     while (newChannel != null) {
       try {
         val newConnectionId = new ConnectionId(id, 
idCount.getAndIncrement.intValue)
-        val newConnection = new ReceivingConnection(newChannel, selector, 
newConnectionId)
+        val newConnection = new ReceivingConnection(newChannel, selector, 
newConnectionId,
+          securityManager)
         newConnection.onReceive(receiveMessage)
         addListeners(newConnection)
         addConnection(newConnection)
@@ -527,9 +526,8 @@ private[nio] class ConnectionManager(
     if (waitingConn.isSaslComplete()) {
       logDebug("Client sasl completed for id: "  + waitingConn.connectionId)
       connectionsAwaitingSasl -= waitingConn.connectionId
-      waitingConn.getAuthenticated().synchronized {
-        waitingConn.getAuthenticated().notifyAll()
-      }
+      waitingConn.registerAfterAuth()
+      wakeupSelector()
       return
     } else {
       var replyToken : Array[Byte] = null
@@ -538,9 +536,8 @@ private[nio] class ConnectionManager(
         if (waitingConn.isSaslComplete()) {
           logDebug("Client sasl completed after evaluate for id: " + 
waitingConn.connectionId)
           connectionsAwaitingSasl -= waitingConn.connectionId
-          waitingConn.getAuthenticated().synchronized {
-            waitingConn.getAuthenticated().notifyAll()
-          }
+          waitingConn.registerAfterAuth()
+          wakeupSelector()
           return
         }
         val securityMsgResp = SecurityMessage.fromResponse(replyToken,
@@ -574,9 +571,11 @@ private[nio] class ConnectionManager(
         }
         replyToken = connection.sparkSaslServer.response(securityMsg.getToken)
         if (connection.isSaslComplete()) {
-          logDebug("Server sasl completed: " + connection.connectionId)
+          logDebug("Server sasl completed: " + connection.connectionId +
+            " for: " + connectionId)
         } else {
-          logDebug("Server sasl not completed: " + connection.connectionId)
+          logDebug("Server sasl not completed: " + connection.connectionId +
+            " for: " + connectionId)
         }
         if (replyToken != null) {
           val securityMsgResp = SecurityMessage.fromResponse(replyToken,
@@ -723,7 +722,8 @@ private[nio] class ConnectionManager(
             if (message == null) throw new Exception("Error creating security 
message")
             connectionsAwaitingSasl += ((conn.connectionId, conn))
             sendSecurityMessage(connManagerId, message)
-            logDebug("adding connectionsAwaitingSasl id: " + conn.connectionId)
+            logDebug("adding connectionsAwaitingSasl id: " + conn.connectionId 
+
+              " to: " + connManagerId)
           } catch {
             case e: Exception => {
               logError("Error getting first response from the SaslClient.", e)
@@ -744,7 +744,7 @@ private[nio] class ConnectionManager(
       val inetSocketAddress = new InetSocketAddress(connManagerId.host, 
connManagerId.port)
       val newConnectionId = new ConnectionId(id, 
idCount.getAndIncrement.intValue)
       val newConnection = new SendingConnection(inetSocketAddress, selector, 
connManagerId,
-        newConnectionId)
+        newConnectionId, securityManager)
       logInfo("creating new sending connection for security! " + 
newConnectionId )
       registerRequests.enqueue(newConnection)
 
@@ -769,61 +769,23 @@ private[nio] class ConnectionManager(
         connectionManagerId.port)
       val newConnectionId = new ConnectionId(id, 
idCount.getAndIncrement.intValue)
       val newConnection = new SendingConnection(inetSocketAddress, selector, 
connectionManagerId,
-        newConnectionId)
+        newConnectionId, securityManager)
       logTrace("creating new sending connection: " + newConnectionId)
       registerRequests.enqueue(newConnection)
 
       newConnection
     }
     val connection = connectionsById.getOrElseUpdate(connectionManagerId, 
startNewConnection())
-    if (authEnabled) {
-      checkSendAuthFirst(connectionManagerId, connection)
-    }
+
     message.senderAddress = id.toSocketAddress()
     logDebug("Before Sending [" + message + "] to [" + connectionManagerId + 
"]" + " " +
       "connectionid: "  + connection.connectionId)
 
     if (authEnabled) {
-      // if we aren't authenticated yet lets block the senders until 
authentication completes
-      try {
-        connection.getAuthenticated().synchronized {
-          val clock = SystemClock
-          val startTime = clock.getTime()
-
-          while (!connection.isSaslComplete()) {
-            logDebug("getAuthenticated wait connectionid: " + 
connection.connectionId)
-            // have timeout in case remote side never responds
-            connection.getAuthenticated().wait(500)
-            if (((clock.getTime() - startTime) >= (authTimeout * 1000))
-              && (!connection.isSaslComplete())) {
-              // took to long to authenticate the connection, something 
probably went wrong
-              throw new Exception("Took to long for authentication to " + 
connectionManagerId +
-                ", waited " + authTimeout + "seconds, failing.")
-            }
-          }
-        }
-      } catch {
-        case e: Exception => logError("Exception while waiting for 
authentication.", e)
-
-        // need to tell sender it failed
-        messageStatuses.synchronized {
-          val s = messageStatuses.get(message.id)
-          s match {
-            case Some(msgStatus) => {
-              messageStatuses -= message.id
-              logInfo("Notifying " + msgStatus.connectionManagerId)
-              msgStatus.markDone(None)
-            }
-            case None => {
-              logError("no messageStatus for failed message id: " + message.id)
-            }
-          }
-        }
-      }
+      checkSendAuthFirst(connectionManagerId, connection)
     }
     logDebug("Sending [" + message + "] to [" + connectionManagerId + "]")
     connection.send(message)
-
     wakeupSelector()
   }
 


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to