This is an automated email from the ASF dual-hosted git repository.

wirebaron pushed a commit to branch develop
in repository https://gitbox.apache.org/repos/asf/geode.git


The following commit(s) were added to refs/heads/develop by this push:
     new 2f81f40  GEODE-4094: ClientHealthMonitor may cause an NPE in a ServerConnection (#1326)
2f81f40 is described below

commit 2f81f40727cfd02296e9e0f041b6ae2bb9cd1a45
Author: Brian Rowe <[email protected]>
AuthorDate: Fri Jan 26 14:22:46 2018 -0800

    GEODE-4094: ClientHealthMonitor may cause an NPE in a ServerConnection (#1326)
    
    * GEODE-4094: ClientHealthMonitor may cause an NPE in a ServerConnection
    
    - minor refactoring of AcceptorImpl and Handshake to improve testability
    - added a unit test to demonstrate race condition
    - refactored connection map into a new object to prevent race
---
 .../internal/cache/tier/sockets/AcceptorImpl.java  |   7 -
 .../cache/tier/sockets/ClientHealthMonitor.java    | 253 +++++++++------------
 .../internal/cache/tier/sockets/HandShake.java     |   2 +-
 .../cache/tier/sockets/ServerConnection.java       |  57 +++--
 .../tier/sockets/ServerConnectionCollection.java   |  44 ++++
 .../tier/sockets/ServerHandShakeProcessor.java     |   8 +-
 .../cache/tier/sockets/ServerConnectionTest.java   | 139 ++++++++++-
 7 files changed, 318 insertions(+), 192 deletions(-)

diff --git 
a/geode-core/src/main/java/org/apache/geode/internal/cache/tier/sockets/AcceptorImpl.java
 
b/geode-core/src/main/java/org/apache/geode/internal/cache/tier/sockets/AcceptorImpl.java
index 360fa48..010e07c 100755
--- 
a/geode-core/src/main/java/org/apache/geode/internal/cache/tier/sockets/AcceptorImpl.java
+++ 
b/geode-core/src/main/java/org/apache/geode/internal/cache/tier/sockets/AcceptorImpl.java
@@ -301,7 +301,6 @@ public class AcceptorImpl implements Acceptor, Runnable, 
CommBufferPool {
   private long acceptorId;
 
   private static boolean isAuthenticationRequired;
-  private static boolean isIntegratedSecurity;
 
   private static boolean isPostAuthzCallbackPresent;
 
@@ -547,8 +546,6 @@ public class AcceptorImpl implements Acceptor, Runnable, 
CommBufferPool {
 
     isAuthenticationRequired = this.securityService.isClientSecurityRequired();
 
-    isIntegratedSecurity = this.securityService.isIntegratedSecurity();
-
     String postAuthzFactoryName =
         
this.cache.getDistributedSystem().getProperties().getProperty(SECURITY_CLIENT_ACCESSOR_PP);
 
@@ -1784,10 +1781,6 @@ public class AcceptorImpl implements Acceptor, Runnable, 
CommBufferPool {
     return isAuthenticationRequired;
   }
 
-  public static boolean isIntegratedSecurity() {
-    return isIntegratedSecurity;
-  }
-
   public static boolean isPostAuthzCallbackPresent() {
     return isPostAuthzCallbackPresent;
   }
diff --git 
a/geode-core/src/main/java/org/apache/geode/internal/cache/tier/sockets/ClientHealthMonitor.java
 
b/geode-core/src/main/java/org/apache/geode/internal/cache/tier/sockets/ClientHealthMonitor.java
index afefb69..d08902d 100644
--- 
a/geode-core/src/main/java/org/apache/geode/internal/cache/tier/sockets/ClientHealthMonitor.java
+++ 
b/geode-core/src/main/java/org/apache/geode/internal/cache/tier/sockets/ClientHealthMonitor.java
@@ -25,7 +25,6 @@ import java.util.Map;
 import java.util.Set;
 import java.util.concurrent.atomic.AtomicIntegerArray;
 
-import org.apache.commons.lang.StringUtils;
 import org.apache.logging.log4j.Logger;
 
 import org.apache.geode.CancelException;
@@ -70,16 +69,6 @@ public class ClientHealthMonitor {
   private final Object _clientHeartbeatsLock = new Object();
 
   /**
-   * The map of known client threads
-   */
-  private final Map _clientThreads;
-
-  /**
-   * An object used to lock the map of client threads
-   */
-  private final Object _clientThreadsLock = new Object();
-
-  /**
    * THe GemFire <code>Cache</code>
    */
   private final InternalCache _cache;
@@ -124,6 +113,12 @@ public class ClientHealthMonitor {
   private final HashMap cleanupProxyIdTable = new HashMap();
 
   /**
+   * Used to track the connections for a particular client
+   */
+  private final HashMap<ClientProxyMembershipID, ServerConnectionCollection> 
proxyIdConnections =
+      new HashMap<>();
+
+  /**
    * Gives, version-wise, the number of clients connected to the cache servers 
in this cache, which
    * are capable of processing recieved deltas.
    *
@@ -349,18 +344,12 @@ public class ClientHealthMonitor {
    * @param proxyID The membership id of the client to be updated
    * @param connection The thread processing client requests
    */
-  public void addConnection(ClientProxyMembershipID proxyID, ServerConnection 
connection) {
-    // logger.info("ClientHealthMonitor: Adding " + connection + " to
-    // client with member id " + proxyID);
-    synchronized (_clientThreadsLock) {
-      Set serverConnections = (Set) this._clientThreads.get(proxyID);
-      if (serverConnections == null) {
-        serverConnections = new HashSet();
-        this._clientThreads.put(proxyID, serverConnections);
-      }
-      serverConnections.add(connection);
-      // logger.info("ClientHealthMonitor: The client with member id " +
-      // proxyID + " contains " + serverConnections.size() + " threads");
+  public ServerConnectionCollection addConnection(ClientProxyMembershipID 
proxyID,
+      ServerConnection connection) {
+    synchronized (proxyIdConnections) {
+      ServerConnectionCollection collection = getProxyIdCollection(proxyID);
+      collection.addConnection(connection);
+      return collection;
     }
   }
 
@@ -371,18 +360,12 @@ public class ClientHealthMonitor {
    * @param connection The thread processing client requests
    */
   public void removeConnection(ClientProxyMembershipID proxyID, 
ServerConnection connection) {
-    // logger.info("ClientHealthMonitor: Removing " + connection + " from
-    // client with member id " + proxyID);
-    synchronized (_clientThreadsLock) {
-      Set serverConnections = (Set) this._clientThreads.get(proxyID);
-      if (serverConnections != null) { // fix for bug 35343
-        serverConnections.remove(connection);
-        // logger.info("ClientHealthMonitor: The client with member id " +
-        // proxyID + " contains " + serverConnections.size() + " threads");
-        if (serverConnections.isEmpty()) {
-          // logger.info("ClientHealthMonitor: The client with member id "
-          // + proxyID + " is being removed since it contains 0 threads");
-          this._clientThreads.remove(proxyID);
+    synchronized (proxyIdConnections) {
+      ServerConnectionCollection collection = proxyIdConnections.get(proxyID);
+      if (collection != null) {
+        collection.removeConnection(connection);
+        if (collection.getConnections().isEmpty()) {
+          proxyIdConnections.remove(proxyID);
         }
       }
     }
@@ -419,24 +402,21 @@ public class ClientHealthMonitor {
    *        ConnectionProxies may be from same client member or different. If 
it is null this would
    *        mean to fetch the Connections of all the ConnectionProxy objects.
    */
-  public Map getConnectedClients(Set filterProxies) {
-    Map map = new HashMap(); // KEY=proxyID, VALUE=connectionCount (Integer)
-    synchronized (_clientThreadsLock) {
-      Iterator connectedClients = this._clientThreads.entrySet().iterator();
-      while (connectedClients.hasNext()) {
-        Map.Entry entry = (Map.Entry) connectedClients.next();
-        ClientProxyMembershipID proxyID = (ClientProxyMembershipID) 
entry.getKey();// proxyID
-                                                                               
    // includes FQDN
+  public Map<String, Object[]> getConnectedClients(Set filterProxies) {
+    Map<String, Object[]> map = new HashMap<>(); // KEY=proxyID, 
VALUE=connectionCount (Integer)
+    synchronized (proxyIdConnections) {
+      for (Map.Entry<ClientProxyMembershipID, ServerConnectionCollection> 
entry : proxyIdConnections
+          .entrySet()) {
+        ClientProxyMembershipID proxyID = entry.getKey();// proxyID
+        // includes FQDN
         if (filterProxies == null || filterProxies.contains(proxyID)) {
           String membershipID = null;
-          Set connections = (Set) entry.getValue();
+          Set<ServerConnection> connections = 
entry.getValue().getConnections();
           int socketPort = 0;
           InetAddress socketAddress = null;
           /// *
-          Iterator serverConnections = connections.iterator();
           // Get data from one.
-          while (serverConnections.hasNext()) {
-            ServerConnection sc = (ServerConnection) serverConnections.next();
+          for (ServerConnection sc : connections) {
             socketPort = sc.getSocketPort();
             socketAddress = sc.getSocketAddress();
             membershipID = sc.getMembershipID();
@@ -453,7 +433,7 @@ public class ClientHealthMonitor {
                 + " client member id=" + membershipID;
           }
           Object[] data = null;
-          data = (Object[]) map.get(membershipID);
+          data = map.get(membershipID);
           if (data == null) {
             map.put(membershipID, new Object[] {clientString, 
Integer.valueOf(connectionCount)});
           } else {
@@ -480,20 +460,17 @@ public class ClientHealthMonitor {
    *
    * @return Map of ClientProxyMembershipID against CacheClientStatus objects.
    */
-  public Map getStatusForAllClients() {
-    Map result = new HashMap();
-    synchronized (_clientThreadsLock) {
-      Iterator connectedClients = this._clientThreads.entrySet().iterator();
-      while (connectedClients.hasNext()) {
-        Map.Entry entry = (Map.Entry) connectedClients.next();
-        ClientProxyMembershipID proxyID = (ClientProxyMembershipID) 
entry.getKey();
+  public Map<ClientProxyMembershipID, CacheClientStatus> 
getStatusForAllClients() {
+    Map<ClientProxyMembershipID, CacheClientStatus> result = new HashMap<>();
+    synchronized (proxyIdConnections) {
+      for (Map.Entry<ClientProxyMembershipID, ServerConnectionCollection> 
entry : proxyIdConnections
+          .entrySet()) {
+        ClientProxyMembershipID proxyID = entry.getKey();
         CacheClientStatus cci = new CacheClientStatus(proxyID);
-        Set connections = (Set) this._clientThreads.get(proxyID);
+        Set<ServerConnection> connections = entry.getValue().getConnections();
         if (connections != null) {
           String memberId = null;
-          Iterator connectionsIterator = connections.iterator();
-          while (connectionsIterator.hasNext()) {
-            ServerConnection sc = (ServerConnection) 
connectionsIterator.next();
+          for (ServerConnection sc : connections) {
             if (sc.isClientServerConnection()) {
               memberId = sc.getMembershipID(); // each ServerConnection has 
the same member id
               cci.setMemberId(memberId);
@@ -508,30 +485,27 @@ public class ClientHealthMonitor {
     return result;
   }
 
-  public void fillInClientInfo(Map allClients) {
+  public void fillInClientInfo(Map<ClientProxyMembershipID, CacheClientStatus> 
allClients) {
     // The allClients parameter includes only actual clients (not remote
     // gateways). This monitor will include remote gateway connections,
     // so weed those out.
-    synchronized (_clientThreadsLock) {
-      Iterator allClientsIterator = allClients.entrySet().iterator();
-      while (allClientsIterator.hasNext()) {
-        Map.Entry entry = (Map.Entry) allClientsIterator.next();
-        ClientProxyMembershipID proxyID = (ClientProxyMembershipID) 
entry.getKey();// proxyID
-                                                                               
    // includes FQDN
-        CacheClientStatus cci = (CacheClientStatus) entry.getValue();
-        Set connections = (Set) this._clientThreads.get(proxyID);
+    synchronized (proxyIdConnections) {
+      for (Map.Entry<ClientProxyMembershipID, CacheClientStatus> entry : 
allClients.entrySet()) {
+        ClientProxyMembershipID proxyID = entry.getKey();// proxyID
+        // includes FQDN
+        CacheClientStatus cci = entry.getValue();
+        ServerConnectionCollection collection = 
proxyIdConnections.get(proxyID);
+        Set<ServerConnection> connections = collection != null ? 
collection.getConnections() : null;
         if (connections != null) {
           String memberId = null;
           cci.setNumberOfConnections(connections.size());
           List socketPorts = new ArrayList();
           List socketAddresses = new ArrayList();
-          Iterator connectionsIterator = connections.iterator();
-          while (connectionsIterator.hasNext()) {
-            ServerConnection sc = (ServerConnection) 
connectionsIterator.next();
+          for (ServerConnection sc : connections) {
             socketPorts.add(Integer.valueOf(sc.getSocketPort()));
             socketAddresses.add(sc.getSocketAddress());
             memberId = sc.getMembershipID(); // each ServerConnection has the
-                                             // same member id
+            // same member id
           }
           cci.setMemberId(memberId);
           cci.setSocketPorts(socketPorts);
@@ -541,17 +515,14 @@ public class ClientHealthMonitor {
     }
   }
 
-  public Map getConnectedIncomingGateways() {
-    Map connectedIncomingGateways = new HashMap();
-    synchronized (_clientThreadsLock) {
-      Iterator connectedClients = this._clientThreads.entrySet().iterator();
-      while (connectedClients.hasNext()) {
-        Map.Entry entry = (Map.Entry) connectedClients.next();
-        ClientProxyMembershipID proxyID = (ClientProxyMembershipID) 
entry.getKey();
-        Set connections = (Set) entry.getValue();
-        Iterator connectionsIterator = connections.iterator();
-        while (connectionsIterator.hasNext()) {
-          ServerConnection sc = (ServerConnection) connectionsIterator.next();
+  public Map<String, IncomingGatewayStatus> getConnectedIncomingGateways() {
+    Map<String, IncomingGatewayStatus> connectedIncomingGateways = new 
HashMap<>();
+    synchronized (proxyIdConnections) {
+      for (Map.Entry<ClientProxyMembershipID, ServerConnectionCollection> 
entry : proxyIdConnections
+          .entrySet()) {
+        ClientProxyMembershipID proxyID = entry.getKey();
+        Set<ServerConnection> connections = entry.getValue().getConnections();
+        for (ServerConnection sc : connections) {
           if (sc.getCommunicationMode().isWAN()) {
             IncomingGatewayStatus status = new 
IncomingGatewayStatus(proxyID.getDSMembership(),
                 sc.getSocketAddress(), sc.getSocketPort());
@@ -566,19 +537,17 @@ public class ClientHealthMonitor {
   protected boolean cleanupClientThreads(ClientProxyMembershipID proxyID, 
boolean timedOut) {
     boolean result = false;
     Set serverConnections = null;
-    synchronized (this._clientThreadsLock) {
-      serverConnections = (Set) this._clientThreads.remove(proxyID);
-      // It is ok to modify the set after releasing the sync
-      // because it has been removed from the map while holding
-      // the sync.
-    } // end sync here to fix bug 37576 and 36740
+    synchronized (proxyIdConnections) {
+      ServerConnectionCollection collection = 
proxyIdConnections.remove(proxyID);
+      if (collection != null) {
+        serverConnections = collection.getConnections();
+      }
+    }
     {
       if (serverConnections != null) { // fix for bug 35343
         result = true;
-        // logger.warn("Terminating " + serverConnections.size() + " 
connections");
         for (Iterator it = serverConnections.iterator(); it.hasNext();) {
           ServerConnection serverConnection = (ServerConnection) it.next();
-          // logger.warn("Terminating " + serverConnection);
           serverConnection.handleTermination(timedOut);
         }
       }
@@ -586,54 +555,51 @@ public class ClientHealthMonitor {
     return result;
   }
 
-  protected boolean isAnyThreadProcessingMessage(ClientProxyMembershipID 
proxyID) {
-    boolean processingMessage = false;
-    synchronized (this._clientThreadsLock) {
-      Set serverConnections = (Set) this._clientThreads.get(proxyID);
-      if (serverConnections != null) {
-        for (Iterator it = serverConnections.iterator(); it.hasNext();) {
-          ServerConnection serverConnection = (ServerConnection) it.next();
-          if (serverConnection.isProcessingMessage()) {
-            processingMessage = true;
-            break;
-          }
-        }
+  // Returns true if the proxyID is idle (no connection is processing a message, or no connections
+  // are found); an existing idle collection is first marked as terminating. Returns false otherwise.
+  private boolean 
prepareToTerminateIfNoConnectionIsProcessing(ClientProxyMembershipID proxyID) {
+    synchronized (proxyIdConnections) {
+      ServerConnectionCollection collection = proxyIdConnections.get(proxyID);
+      if (collection == null) {
+        return true;
+      }
+      if (collection.connectionsProcessing.get() == 0) {
+        collection.isTerminating = true;
+        return true;
+      } else {
+        return false;
       }
     }
-    return processingMessage;
   }
 
   protected void validateThreads(ClientProxyMembershipID proxyID) {
-    Set serverConnections = null;
-    synchronized (this._clientThreadsLock) {
-      serverConnections = (Set) this._clientThreads.get(proxyID);
-      if (serverConnections != null) {
-        serverConnections = new HashSet(serverConnections);
-      }
+    Set<ServerConnection> serverConnections;
+    synchronized (proxyIdConnections) {
+      ServerConnectionCollection collection = proxyIdConnections.get(proxyID);
+      serverConnections =
+          collection != null ? new HashSet<>(collection.getConnections()) : 
Collections.emptySet();
     }
     // release sync and operation on copy to fix bug 37675
-    if (serverConnections != null) {
-      for (Iterator it = serverConnections.iterator(); it.hasNext();) {
-        ServerConnection serverConnection = (ServerConnection) it.next();
-        if (serverConnection.hasBeenTimedOutOnClient()) {
-          logger.warn(LocalizedMessage.create(
-              
LocalizedStrings.ClientHealtMonitor_0_IS_BEING_TERMINATED_BECAUSE_ITS_CLIENT_TIMEOUT_OF_1_HAS_EXPIRED,
-              new Object[] {serverConnection,
-                  Integer.valueOf(serverConnection.getClientReadTimeout())}));
-          try {
-            serverConnection.handleTermination(true);
-            // Not all the code in a ServerConnection correctly
-            // handles interrupt. In particular it is possible to be doing
-            // p2p distribution and to have sent a message to one peer but
-            // to never send it to another due to interrupt.
-            // serverConnection.interruptOwner();
-          } finally {
-            // Just to be sure we clean it up.
-            // This call probably isn't needed.
-            removeConnection(proxyID, serverConnection);
-          }
+    for (ServerConnection serverConnection : serverConnections) {
+      if (serverConnection.hasBeenTimedOutOnClient()) {
+        logger.warn(LocalizedMessage.create(
+            
LocalizedStrings.ClientHealtMonitor_0_IS_BEING_TERMINATED_BECAUSE_ITS_CLIENT_TIMEOUT_OF_1_HAS_EXPIRED,
+            new Object[] {serverConnection,
+                Integer.valueOf(serverConnection.getClientReadTimeout())}));
+        try {
+          serverConnection.handleTermination(true);
+          // Not all the code in a ServerConnection correctly
+          // handles interrupt. In particular it is possible to be doing
+          // p2p distribution and to have sent a message to one peer but
+          // to never send it to another due to interrupt.
+          // serverConnection.interruptOwner();
+        } finally {
+          // Just to be sure we clean it up.
+          // This call probably isn't needed.
+          removeConnection(proxyID, serverConnection);
         }
       }
+
     }
   }
 
@@ -688,9 +654,6 @@ public class ClientHealthMonitor {
     this._cache = cache;
     this.maximumTimeBetweenPings = maximumTimeBetweenPings;
 
-    // Initialize the client threads map
-    this._clientThreads = new HashMap();
-
     this.monitorInterval = 
Long.getLong(CLIENT_HEALTH_MONITOR_INTERVAL_PROPERTY,
         DEFAULT_CLIENT_MONITOR_INTERVAL_IN_MILLIS);
     logger.debug("Setting monitorInterval to {}", this.monitorInterval);
@@ -722,6 +685,10 @@ public class ClientHealthMonitor {
     return "ClientHealthMonitor@" + 
Integer.toHexString(System.identityHashCode(this));
   }
 
+  public ServerConnectionCollection 
getProxyIdCollection(ClientProxyMembershipID proxyID) {
+    return proxyIdConnections.computeIfAbsent(proxyID, key -> new 
ServerConnectionCollection());
+  }
+
   public Map getCleanupProxyIdTable() {
     return cleanupProxyIdTable;
   }
@@ -828,8 +795,6 @@ public class ClientHealthMonitor {
           if (logger.isTraceEnabled()) {
             logger.trace("Monitoring {} client(s)", 
getClientHeartbeats().size());
           }
-          // logger.warning("Monitoring " + getClientHeartbeats().size() +
-          // " client(s).");
 
           // Get the current time
           long currentTime = System.currentTimeMillis();
@@ -863,19 +828,19 @@ public class ClientHealthMonitor {
                 // This client has been idle for too long. Determine whether
                 // any of its ServerConnection threads are currently processing
                 // a message. If so, let it go. If not, disconnect it.
-                if (isAnyThreadProcessingMessage(proxyID)) {
-                  if (logger.isDebugEnabled()) {
-                    logger.debug(
-                        "Monitoring client with member id {}. It has been {} 
ms since the latest heartbeat. This client would have been terminated but at 
least one of its threads is processing a message.",
-                        entry.getKey(), (currentTime - latestHeartbeat));
-                  }
-                } else {
+                if (prepareToTerminateIfNoConnectionIsProcessing(proxyID)) {
                   if (cleanupClientThreads(proxyID, true)) {
                     logger.warn(LocalizedMessage.create(
                         
LocalizedStrings.ClientHealthMonitor_MONITORING_CLIENT_WITH_MEMBER_ID_0_IT_HAD_BEEN_1_MS_SINCE_THE_LATEST_HEARTBEAT_MAX_INTERVAL_IS_2_TERMINATED_CLIENT,
                         new Object[] {entry.getKey(), currentTime - 
latestHeartbeat,
                             this._maximumTimeBetweenPings}));
                   }
+                } else {
+                  if (logger.isDebugEnabled()) {
+                    logger.debug(
+                        "Monitoring client with member id {}. It has been {} 
ms since the latest heartbeat. This client would have been terminated but at 
least one of its threads is processing a message.",
+                        entry.getKey(), (currentTime - latestHeartbeat));
+                  }
                 }
               } else {
                 if (logger.isTraceEnabled()) {
@@ -883,10 +848,6 @@ public class ClientHealthMonitor {
                       "Monitoring client with member id {}. It has been {} ms 
since the latest heartbeat. This client is healthy.",
                       entry.getKey(), (currentTime - latestHeartbeat));
                 }
-                // logger.warning("Monitoring client with member id " +
-                // entry.getKey() + ". It has been " + (currentTime -
-                // latestHeartbeat) + " ms since the latest heartbeat. This
-                // client is healthy.");
               }
             }
           }
diff --git 
a/geode-core/src/main/java/org/apache/geode/internal/cache/tier/sockets/HandShake.java
 
b/geode-core/src/main/java/org/apache/geode/internal/cache/tier/sockets/HandShake.java
index 91293f0..4b1f2c7 100644
--- 
a/geode-core/src/main/java/org/apache/geode/internal/cache/tier/sockets/HandShake.java
+++ 
b/geode-core/src/main/java/org/apache/geode/internal/cache/tier/sockets/HandShake.java
@@ -1683,7 +1683,7 @@ public class HandShake implements ClientHandShake {
 
     Authenticator auth = null;
     try {
-      if (AcceptorImpl.isIntegratedSecurity()) {
+      if (securityService.isIntegratedSecurity()) {
         return securityService.login(credentials);
       } else {
         Method instanceGetter = 
ClassLoadUtil.methodFromName(authenticatorMethod);
diff --git 
a/geode-core/src/main/java/org/apache/geode/internal/cache/tier/sockets/ServerConnection.java
 
b/geode-core/src/main/java/org/apache/geode/internal/cache/tier/sockets/ServerConnection.java
index 273485e..d4e5969 100644
--- 
a/geode-core/src/main/java/org/apache/geode/internal/cache/tier/sockets/ServerConnection.java
+++ 
b/geode-core/src/main/java/org/apache/geode/internal/cache/tier/sockets/ServerConnection.java
@@ -42,7 +42,6 @@ import org.apache.geode.SystemFailure;
 import org.apache.geode.cache.client.internal.AbstractOp;
 import org.apache.geode.cache.client.internal.Connection;
 import org.apache.geode.distributed.DistributedSystem;
-import 
org.apache.geode.distributed.internal.membership.InternalDistributedMember;
 import org.apache.geode.internal.Assert;
 import org.apache.geode.internal.HeapDataOutputStream;
 import org.apache.geode.internal.Version;
@@ -97,7 +96,7 @@ public abstract class ServerConnection implements Runnable {
 
   private Map commands;
 
-  private final SecurityService securityService;
+  protected final SecurityService securityService;
 
   protected final CacheServerStats stats;
 
@@ -107,6 +106,7 @@ public abstract class ServerConnection implements Runnable {
   // The key is the size of each ByteBuffer. The value is a queue of byte 
buffers all of that size.
   private static final ConcurrentHashMap<Integer, 
LinkedBlockingQueue<ByteBuffer>> commBufferMap =
       new ConcurrentHashMap<>(4, 0.75f, 1);
+  private ServerConnectionCollection serverConnectionCollection;
 
   public static ByteBuffer allocateCommBuffer(int size, Socket sock) {
     // I expect that size will almost always be the same value
@@ -194,17 +194,8 @@ public abstract class ServerConnection implements Runnable 
{
    */
   private int latestBatchIdReplied = -1;
 
-  /*
-   * Uniquely identifying the client's Distributed System
-   *
-   *
-   * private String membershipId;
-   *
-   *
-   * Uniquely identifying the client's ConnectionProxy object
-   *
-   *
-   * private String proxyID ;
+  /**
+   * Client identity from handshake
    */
   ClientProxyMembershipID proxyId;
 
@@ -333,7 +324,8 @@ public abstract class ServerConnection implements Runnable {
     synchronized (this.handShakeMonitor) {
       if (this.handshake == null) {
         // synchronized (getCleanupTable()) {
-        boolean readHandShake = ServerHandShakeProcessor.readHandShake(this, 
getSecurityService());
+        boolean readHandShake =
+            ServerHandShakeProcessor.readHandShake(this, getSecurityService(), 
acceptor);
         if (readHandShake) {
           if (this.handshake.isOK()) {
             try {
@@ -469,11 +461,7 @@ public abstract class ServerConnection implements Runnable 
{
     return acceptor.getClientHealthMonitor().getCleanupProxyIdTable();
   }
 
-  private ClientHealthMonitor getClientHealthMonitor() {
-    return acceptor.getClientHealthMonitor();
-  }
-
-  private boolean processHandShake() {
+  protected boolean processHandShake() {
     boolean result = false;
     boolean clientJoined = false;
     boolean registerClient = false;
@@ -559,8 +547,6 @@ public abstract class ServerConnection implements Runnable {
           numRefs = new Counter();
           numRefs.incr();
           getCleanupProxyIdTable().put(this.proxyId, numRefs);
-          InternalDistributedMember idm =
-              (InternalDistributedMember) this.proxyId.getDistributedMember();
         }
         this.incedCleanupProxyIdTableRef = true;
       }
@@ -583,7 +569,7 @@ public abstract class ServerConnection implements Runnable {
         chm.registerClient(this.proxyId);
       }
       // hitesh:it will add client connection in set
-      chm.addConnection(this.proxyId, this);
+      serverConnectionCollection = chm.addConnection(this.proxyId, this);
       this.acceptor.getConnectionListener().connectionOpened(registerClient, 
communicationMode);
       // Hitesh: add user creds in map for single user case.
     } // finally
@@ -725,8 +711,22 @@ public abstract class ServerConnection implements Runnable 
{
   }
 
   protected void doNormalMsg() {
+    if (serverConnectionCollection == null) {
+      // return here if we haven't successfully completed handshake
+      logger.warn("Continued processing ServerConnection after handshake 
failed");
+      this.processMessages = false;
+      return;
+    }
     Message msg = null;
     msg = BaseCommand.readRequest(this);
+    synchronized (serverConnectionCollection) {
+      if (serverConnectionCollection.isTerminating) {
+        // Client is being disconnected, don't try to process message.
+        this.processMessages = false;
+        return;
+      }
+      serverConnectionCollection.connectionsProcessing.incrementAndGet();
+    }
     ThreadState threadState = null;
     try {
       if (msg != null) {
@@ -775,7 +775,7 @@ public abstract class ServerConnection implements Runnable {
 
         // if a subject exists for this uniqueId, binds the subject to this 
thread so that we can do
         // authorization later
-        if (AcceptorImpl.isIntegratedSecurity()
+        if (securityService.isIntegratedSecurity()
             && !isInternalMessage(this.requestMsg, 
allowInternalMessagesWithoutCredentials)
             && !this.communicationMode.isWAN()) {
           long uniqueId = getUniqueId();
@@ -799,13 +799,13 @@ public abstract class ServerConnection implements 
Runnable {
     } finally {
       // Keep track of the fact that a message is no longer being
       // processed.
+      serverConnectionCollection.connectionsProcessing.decrementAndGet();
       setNotProcessingMessage();
       clearRequestMsg();
       if (threadState != null) {
         threadState.clear();
       }
     }
-
   }
 
   private final Object terminationLock = new Object();
@@ -874,8 +874,6 @@ public abstract class ServerConnection implements Runnable {
             getCleanupProxyIdTable().remove(this.proxyId);
             // here we can remove entry multiuser map for client
             proxyIdVsClientUserAuths.remove(this.proxyId);
-            InternalDistributedMember idm =
-                (InternalDistributedMember) 
this.proxyId.getDistributedMember();
           }
         }
       }
@@ -937,7 +935,7 @@ public abstract class ServerConnection implements Runnable {
     return retCua;
   }
 
-  private void initializeCommands() {
+  protected void initializeCommands() {
     // The commands are cached here, but are just referencing the ones
     // stored in the CommandInitializer
     this.commands = CommandInitializer.getCommands(this);
@@ -1499,6 +1497,7 @@ public abstract class ServerConnection implements 
Runnable {
       logger.debug("{}: Closed connection", this.name);
     }
     releaseCommBuffer();
+    processMessages = false;
     return true;
   }
 
@@ -1762,7 +1761,7 @@ public abstract class ServerConnection implements 
Runnable {
       return null;
     }
 
-    if (AcceptorImpl.isIntegratedSecurity()) {
+    if (securityService.isIntegratedSecurity()) {
       return null;
     }
 
@@ -1796,7 +1795,7 @@ public abstract class ServerConnection implements 
Runnable {
       return null;
     }
 
-    if (AcceptorImpl.isIntegratedSecurity()) {
+    if (securityService.isIntegratedSecurity()) {
       return null;
     }
 
diff --git 
a/geode-core/src/main/java/org/apache/geode/internal/cache/tier/sockets/ServerConnectionCollection.java
 
b/geode-core/src/main/java/org/apache/geode/internal/cache/tier/sockets/ServerConnectionCollection.java
new file mode 100644
index 0000000..670b5a1
--- /dev/null
+++ 
b/geode-core/src/main/java/org/apache/geode/internal/cache/tier/sockets/ServerConnectionCollection.java
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more 
contributor license
+ * agreements. See the NOTICE file distributed with this work for additional 
information regarding
+ * copyright ownership. The ASF licenses this file to You under the Apache 
License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance with the 
License. You may obtain a
+ * copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software 
distributed under the License
+ * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 
KIND, either express
+ * or implied. See the License for the specific language governing permissions 
and limitations under
+ * the License.
+ */
+package org.apache.geode.internal.cache.tier.sockets;
+
+import java.util.HashSet;
+import java.util.Set;
+import java.util.concurrent.atomic.AtomicInteger;
+
+/**
+ * Groups the {@link ServerConnection}s that belong to a single client.
+ *
+ * <p>
+ * These objects are managed by the ClientHealthMonitor, which also manages their synchronization.
+ * The underlying connection set is a plain {@link HashSet} and is NOT thread-safe on its own:
+ * adding, removing and iterating must happen under that external synchronization.
+ */
+public class ServerConnectionCollection {
+  // Final so the set identity handed out by getConnections() can never be swapped underneath a
+  // caller that is iterating it.
+  private final Set<ServerConnection> connectionSet = new HashSet<>();
+
+  /** Number of connections currently processing messages for this client. */
+  final AtomicInteger connectionsProcessing = new AtomicInteger();
+
+  /**
+   * Indicates that the server is soon to be, or already is, in the process of terminating the
+   * connections in this collection.
+   */
+  volatile boolean isTerminating = false;
+
+  /** Adds {@code connection} to this client's group. Not thread-safe; see the class comment. */
+  public void addConnection(ServerConnection connection) {
+    connectionSet.add(connection);
+  }
+
+  /**
+   * Returns the live, mutable set backing this collection (not a copy). Callers must hold the
+   * external synchronization described in the class comment while iterating it.
+   */
+  public Set<ServerConnection> getConnections() {
+    return connectionSet;
+  }
+
+  /** Removes {@code connection} from this client's group. Not thread-safe; see the class comment. */
+  public void removeConnection(ServerConnection connection) {
+    connectionSet.remove(connection);
+  }
+}
diff --git 
a/geode-core/src/main/java/org/apache/geode/internal/cache/tier/sockets/ServerHandShakeProcessor.java
 
b/geode-core/src/main/java/org/apache/geode/internal/cache/tier/sockets/ServerHandShakeProcessor.java
index c4265cd..e292813 100755
--- 
a/geode-core/src/main/java/org/apache/geode/internal/cache/tier/sockets/ServerHandShakeProcessor.java
+++ 
b/geode-core/src/main/java/org/apache/geode/internal/cache/tier/sockets/ServerHandShakeProcessor.java
@@ -79,8 +79,8 @@ public class ServerHandShakeProcessor {
     currentServerVersion = Version.fromOrdinalOrCurrent(ver);
   }
 
-  public static boolean readHandShake(ServerConnection connection,
-      SecurityService securityService) {
+  public static boolean readHandShake(ServerConnection connection, 
SecurityService securityService,
+      AcceptorImpl acceptorImpl) {
     boolean validHandShake = false;
     Version clientVersion = null;
     try {
@@ -123,7 +123,7 @@ public class ServerHandShakeProcessor {
 
       // Read the appropriate handshake
       if (clientVersion.compareTo(Version.GFE_57) >= 0) {
-        validHandShake = readGFEHandshake(connection, clientVersion, 
securityService);
+        validHandShake = readGFEHandshake(connection, clientVersion, 
securityService, acceptorImpl);
       } else {
         connection.refuseHandshake(
             "Unsupported version " + clientVersion + "Server's current version 
" + Acceptor.VERSION,
@@ -200,7 +200,7 @@ public class ServerHandShakeProcessor {
   }
 
   private static boolean readGFEHandshake(ServerConnection connection, Version 
clientVersion,
-      SecurityService securityService) {
+      SecurityService securityService, AcceptorImpl acceptorImpl) {
     int handShakeTimeout = connection.getHandShakeTimeout();
     InternalLogWriter securityLogWriter = connection.getSecurityLogWriter();
     try {
diff --git 
a/geode-core/src/test/java/org/apache/geode/internal/cache/tier/sockets/ServerConnectionTest.java
 
b/geode-core/src/test/java/org/apache/geode/internal/cache/tier/sockets/ServerConnectionTest.java
index dbda3d7..b7f0e7b 100644
--- 
a/geode-core/src/test/java/org/apache/geode/internal/cache/tier/sockets/ServerConnectionTest.java
+++ 
b/geode-core/src/test/java/org/apache/geode/internal/cache/tier/sockets/ServerConnectionTest.java
@@ -20,6 +20,7 @@ package org.apache.geode.internal.cache.tier.sockets;
 import static 
org.apache.geode.internal.i18n.LocalizedStrings.HandShake_NO_SECURITY_CREDENTIALS_ARE_PROVIDED;
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatThrownBy;
+import static org.junit.Assert.fail;
 import static org.mockito.Matchers.any;
 import static org.mockito.Matchers.anyLong;
 import static org.mockito.Mockito.mock;
@@ -29,6 +30,10 @@ import java.io.IOException;
 import java.net.InetAddress;
 import java.net.Socket;
 import java.util.Locale;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.locks.Condition;
+import java.util.concurrent.locks.Lock;
+import java.util.concurrent.locks.ReentrantLock;
 
 import org.junit.Before;
 import org.junit.Rule;
@@ -38,10 +43,16 @@ import org.mockito.InjectMocks;
 import org.mockito.Mock;
 import org.mockito.MockitoAnnotations;
 
+import org.apache.geode.cache.client.internal.Connection;
+import 
org.apache.geode.distributed.internal.membership.InternalDistributedMember;
 import org.apache.geode.i18n.StringId;
 import org.apache.geode.internal.Version;
 import org.apache.geode.internal.cache.InternalCache;
+import org.apache.geode.internal.cache.TXManagerImpl;
+import org.apache.geode.internal.cache.tier.Acceptor;
+import org.apache.geode.internal.cache.tier.CachedRegionHelper;
 import org.apache.geode.internal.cache.tier.CommunicationMode;
+import org.apache.geode.internal.cache.tier.MessageType;
 import org.apache.geode.internal.security.SecurityService;
 import org.apache.geode.security.AuthenticationRequiredException;
 import org.apache.geode.test.junit.categories.UnitTest;
@@ -70,21 +81,29 @@ public class ServerConnectionTest {
   @InjectMocks
   private ServerConnection serverConnection;
 
+  private AcceptorImpl acceptor;
+  private Socket socket;
+  private InternalCache cache;
+  private SecurityService securityService;
+  private CacheServerStats stats;
+
+  /**
+   * Creates the mocks shared by every test (acceptor, socket, cache, security service and stats),
+   * keeping them in fields so individual tests can stub them further, and builds a default
+   * {@code serverConnection} through {@link ServerConnectionFactory}.
+   */
   @Before
   public void setUp() throws IOException {
-    AcceptorImpl acceptor = mock(AcceptorImpl.class);
+    acceptor = mock(AcceptorImpl.class);

     InetAddress inetAddress = mock(InetAddress.class);
     when(inetAddress.getHostAddress()).thenReturn("localhost");

-    Socket socket = mock(Socket.class);
+    socket = mock(Socket.class);
     when(socket.getInetAddress()).thenReturn(inetAddress);

-    InternalCache cache = mock(InternalCache.class);
-    SecurityService securityService = mock(SecurityService.class);
+    cache = mock(InternalCache.class);
+    securityService = mock(SecurityService.class);
+
+    stats = mock(CacheServerStats.class);

     serverConnection =
-        new ServerConnectionFactory().makeServerConnection(socket, cache, 
null, null, 0, 0, null,
+        new ServerConnectionFactory().makeServerConnection(socket, cache, 
null, stats, 0, 0, null,
             CommunicationMode.PrimaryServerToClient.getModeNumber(), acceptor, 
securityService);
     MockitoAnnotations.initMocks(this);
   }
@@ -139,4 +158,114 @@ public class ServerConnectionTest {
         
.hasMessage(HandShake_NO_SECURITY_CREDENTIALS_ARE_PROVIDED.getRawText());
   }
 
+  /**
+   * A fake request whose {@link #recv()} blocks until {@link #continueProcessing()} is called, so
+   * the test can let the health monitor terminate the connection before the "message" arrives.
+   */
+  class TestMessage extends Message {
+    private final Lock lock = new ReentrantLock();
+    private final Condition testGate = lock.newCondition();
+    // Guarded by lock; set once continueProcessing() has been called.
+    private boolean signalled = false;
+
+    public TestMessage() {
+      super(3, Version.CURRENT);
+      messageType = MessageType.REQUEST;
+      securePart = new Part();
+    }
+
+    /**
+     * Blocks for up to 10 seconds until {@link #continueProcessing()} has been called, and fails
+     * the test if it never is. Waits in a loop on the {@code signalled} flag, so a signal sent
+     * before this method starts waiting is not lost and spurious wake-ups are ignored.
+     */
+    @Override
+    public void recv() throws IOException {
+      boolean wasSignalled;
+      // Acquire outside the try: if lock() itself failed, unlock() in finally would throw.
+      lock.lock();
+      try {
+        long remainingNanos = TimeUnit.SECONDS.toNanos(10);
+        while (!signalled && remainingNanos > 0) {
+          remainingNanos = testGate.awaitNanos(remainingNanos);
+        }
+        wasSignalled = signalled;
+      } catch (InterruptedException e) {
+        Thread.currentThread().interrupt();
+        throw new RuntimeException(e);
+      } finally {
+        lock.unlock();
+      }
+      if (!wasSignalled) {
+        fail("Message never received continueProcessing call");
+      }
+    }
+
+    /** Releases a thread blocked in {@link #recv()}, or lets a later call return immediately. */
+    public void continueProcessing() {
+      lock.lock();
+      try {
+        signalled = true;
+        testGate.signal();
+      } finally {
+        lock.unlock();
+      }
+    }
+  }
+
+  /**
+   * A LegacyServerConnection whose handshake is replaced with mocks and whose only request is a
+   * blocking {@link TestMessage}. The message is released only from {@link #handleTermination},
+   * so the request arrives after the connection has been terminated (the GEODE-4094 race).
+   */
+  class TestServerConnection extends LegacyServerConnection {
+
+    // Installed by setFakeRequest() during doHandshake(); released in handleTermination().
+    private TestMessage testMessage;
+
+    /**
+     * Creates a new <code>ServerConnection</code> that processes messages received from an edge
+     * client over a given <code>Socket</code>. All arguments are passed straight to the
+     * LegacyServerConnection constructor.
+     */
+    public TestServerConnection(Socket socket, InternalCache internalCache,
+        CachedRegionHelper helper, CacheServerStats stats, int hsTimeout, int 
socketBufferSize,
+        String communicationModeStr, byte communicationMode, Acceptor acceptor,
+        SecurityService securityService) {
+      super(socket, internalCache, helper, stats, hsTimeout, socketBufferSize, 
communicationModeStr,
+          communicationMode, acceptor, securityService);
+
+      // NOTE: it is not clear where this is supposed to be set in the timeout path, so it is set
+      // here unconditionally.
+      setClientDisconnectCleanly();
+    }
+
+    /**
+     * Replaces the real handshake with mocks: a HandShake reporting Version.CURRENT and a mocked
+     * ClientProxyMembershipID. It then installs the blocking TestMessage as the request and stubs
+     * a MessageIdExtractor to return -1 for that message.
+     */
+    @Override
+    protected void doHandshake() {
+      ClientProxyMembershipID proxyID = mock(ClientProxyMembershipID.class);
+      
when(proxyID.getDistributedMember()).thenReturn(mock(InternalDistributedMember.class));
+      HandShake handShake = mock(HandShake.class);
+      when(handShake.getMembership()).thenReturn(proxyID);
+      when(handShake.getVersion()).thenReturn(Version.CURRENT);
+
+      setHandshake(handShake);
+      setProxyId(proxyID);
+
+      processHandShake();
+      initializeCommands();
+
+      // Must run before the extractor is stubbed: the stub is keyed on getRequestMessage(),
+      // which presumably returns the message installed here.
+      setFakeRequest();
+
+      long fakeId = -1;
+      MessageIdExtractor extractor = mock(MessageIdExtractor.class);
+      when(extractor.getUniqueIdFromMessage(getRequestMessage(), handShake,
+          Connection.DEFAULT_CONNECTION_ID)).thenReturn(fakeId);
+      setMessageIdExtractor(extractor);
+    }
+
+    // Runs the real termination first, then releases the blocked TestMessage so the pending
+    // request is only "received" after the connection has been terminated.
+    @Override
+    void handleTermination(boolean timedOut) {
+      super.handleTermination(timedOut);
+      testMessage.continueProcessing();
+    }
+
+    // Creates the blocking TestMessage and installs it as this connection's request message.
+    private void setFakeRequest() {
+      testMessage = new TestMessage();
+      setRequestMsg(testMessage);
+    }
+  }
+
+  /**
+   * This test sets up a TestServerConnection which will register with the ClientHealthMonitor and
+   * then block waiting to receive a fake message. This message will arrive just after the health
+   * monitor times out this connection and kills it. The test then makes sure that the connection
+   * correctly handles the terminated state and exits: it passes if run() returns without
+   * TestMessage.recv() failing.
+   */
+  @Test
+  public void terminatingConnectionHandlesNewRequestsGracefully() throws Exception {
+    when(cache.getCacheTransactionManager()).thenReturn(mock(TXManagerImpl.class));
+    ClientHealthMonitor.createInstance(cache, 100, mock(CacheClientNotifierStats.class));
+    try {
+      ClientHealthMonitor clientHealthMonitor = ClientHealthMonitor.getInstance();
+      when(acceptor.getClientHealthMonitor()).thenReturn(clientHealthMonitor);
+      when(acceptor.getConnectionListener()).thenReturn(mock(ConnectionListener.class));
+      when(securityService.isIntegratedSecurity()).thenReturn(true);
+
+      // setUp() has already run MockitoAnnotations.initMocks(this); no second call is needed.
+      TestServerConnection testServerConnection =
+          new TestServerConnection(socket, cache, mock(CachedRegionHelper.class), stats, 0, 0, null,
+              CommunicationMode.PrimaryServerToClient.getModeNumber(), acceptor, securityService);
+
+      testServerConnection.run();
+    } finally {
+      // The monitor is a static singleton (createInstance/getInstance). Shut it down so its state
+      // does not leak into other tests running in the same JVM.
+      ClientHealthMonitor.shutdownInstance();
+    }
+  }
 }

-- 
To stop receiving notification emails like this one, please contact
[email protected].

Reply via email to