Test for SSHD-721. Many concurrent port forward connections deadlock.

Project: http://git-wip-us.apache.org/repos/asf/mina-sshd/repo
Commit: http://git-wip-us.apache.org/repos/asf/mina-sshd/commit/5abc47e8
Tree: http://git-wip-us.apache.org/repos/asf/mina-sshd/tree/5abc47e8
Diff: http://git-wip-us.apache.org/repos/asf/mina-sshd/diff/5abc47e8

Branch: refs/heads/master
Commit: 5abc47e8cc4655c8d58b6d8b41c291c33c2e13db
Parents: 83925cb
Author: bkuker <bku...@martellotech.com>
Authored: Fri Apr 13 16:19:46 2018 -0400
Committer: bkuker <bku...@martellotech.com>
Committed: Mon Apr 16 08:20:54 2018 -0400

----------------------------------------------------------------------
 .../forward/ConcurrentConnectionTest.java       | 281 +++++++++++++++++++
 1 file changed, 281 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/mina-sshd/blob/5abc47e8/sshd-core/src/test/java/org/apache/sshd/common/forward/ConcurrentConnectionTest.java
----------------------------------------------------------------------
diff --git a/sshd-core/src/test/java/org/apache/sshd/common/forward/ConcurrentConnectionTest.java b/sshd-core/src/test/java/org/apache/sshd/common/forward/ConcurrentConnectionTest.java
new file mode 100644
index 0000000..38434df
--- /dev/null
+++ b/sshd-core/src/test/java/org/apache/sshd/common/forward/ConcurrentConnectionTest.java
@@ -0,0 +1,281 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sshd.common.forward;
+
+import java.io.IOException;
+import java.net.InetSocketAddress;
+import java.net.ServerSocket;
+import java.net.Socket;
+import java.net.SocketException;
+import java.nio.charset.StandardCharsets;
+import java.util.concurrent.CyclicBarrier;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicInteger;
+
+import org.apache.sshd.client.SshClient;
+import org.apache.sshd.client.session.ClientSession;
+import org.apache.sshd.common.session.Session;
+import org.apache.sshd.common.util.net.SshdSocketAddress;
+import org.apache.sshd.server.SshServer;
+import org.apache.sshd.server.forward.AcceptAllForwardingFilter;
+import org.apache.sshd.server.forward.ForwardingFilter;
+import org.apache.sshd.server.keyprovider.SimpleGeneratorHostKeyProvider;
+import org.apache.sshd.util.test.BaseTestSupport;
+import org.junit.After;
+import org.junit.AfterClass;
+import org.junit.Before;
+import org.junit.BeforeClass;
+import org.junit.Test;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Port forwarding test with multiple clients connecting at once (SSHD-721).
+ *
+ * A plain TCP "test server" is started for every test. An SSH session then
+ * requests a remote port forward that targets this test server, and
+ * {@link #PORT_FORWARD_CLIENT_COUNT} client threads connect to the forwarded
+ * port at the same moment. Every client must receive the complete
+ * {@link #PAYLOAD_TO_CLIENT} reply.
+ */
+public class ConcurrentConnectionTest extends BaseTestSupport {
+       private static final byte[] PAYLOAD_TO_SERVER =
+                       "To Server -> To Server -> To Server".getBytes(StandardCharsets.UTF_8);
+       private static final byte[] PAYLOAD_TO_CLIENT =
+                       "<- To Client <- To Client <-".getBytes(StandardCharsets.UTF_8);
+       private static final Logger LOG = LoggerFactory.getLogger(ConcurrentConnectionTest.class);
+
+       // These are the critical test parameters.
+       // When the number of clients is greater than or equal to the number of IO
+       // Workers, the server deadlocks
+       private static final int SSHD_NIO_WORKERS = 8;
+       private static final int PORT_FORWARD_CLIENT_COUNT = 12;
+
+       // For very large numbers of clients and small numbers of threads this may
+       // need to be increased
+       private static final int TIMEOUT = (int) TimeUnit.SECONDS.toMillis(10L);
+
+       // Upper bound (msec.) for ALL client threads to finish. A single client
+       // waits at most TIMEOUT on the barrier, TIMEOUT on connect and TIMEOUT
+       // on each blocking read, so three times TIMEOUT is used as the limit.
+       private static final long CLIENTS_COMPLETION_TIMEOUT = 3L * TIMEOUT;
+
+       // SSHD Server State
+       private static int sshServerPort;
+       private static SshServer server;
+
+       // Test Server State
+       private int testServerPort;
+       private ServerSocket testServerSock;
+       private Thread testServerThread;
+
+       // SSH Client State
+       private SshClient client;
+       private ClientSession session;
+
+       /**
+        * Starts the SSH server shared by all tests of this class.
+        *
+        * @throws IOException if the server cannot be started
+        */
+       @BeforeClass
+       public static void startSshServer() throws IOException {
+               LOG.debug("Starting SSHD...");
+               server = SshServer.setUpDefaultServer();
+               server.setPasswordAuthenticator((u, p, s) -> true);
+               server.setKeyPairProvider(new SimpleGeneratorHostKeyProvider());
+               server.setNioWorkers(SSHD_NIO_WORKERS);
+               server.setForwardingFilter(new ForwardingFilter() {
+
+                       @Override
+                       public boolean canListen(SshdSocketAddress address, Session session) {
+                               // The test requests a remote port forward, so listening is allowed
+                               return true;
+                       }
+
+                       @Override
+                       public boolean canConnect(Type type, SshdSocketAddress address, Session session) {
+                               // Not exercised by this test
+                               return false;
+                       }
+
+                       @Override
+                       public boolean canForwardX11(Session session, String requestType) {
+                               // Not exercised by this test
+                               return false;
+                       }
+
+                       @Override
+                       public boolean canForwardAgent(Session session, String requestType) {
+                               // Not exercised by this test
+                               return false;
+                       }
+               });
+               server.start();
+               sshServerPort = server.getPort();
+               LOG.debug("SSHD Running on port {}", server.getPort());
+       }
+
+       /**
+        * Stops the shared SSH server. Does nothing if the server was never
+        * created, so that a failure in {@link #startSshServer()} is not masked
+        * by a NullPointerException here.
+        *
+        * @throws IOException if waiting for the server to close fails
+        */
+       @AfterClass
+       public static void stopServer() throws IOException {
+               if (server == null) {
+                       return;
+               }
+               try {
+                       if (!server.close(true).await(TIMEOUT)) {
+                               LOG.warn("Failed to close server within {} sec.", TimeUnit.MILLISECONDS.toSeconds(TIMEOUT));
+                       }
+               } finally {
+                       server = null;
+               }
+       }
+
+       /**
+        * Start a server to forward to.
+        *
+        * Reads PAYLOAD_TO_SERVER from client and then sends PAYLOAD_TO_CLIENT to
+        * client. This server emulates a web server, closely enough for this test.
+        *
+        * The listening socket is bound on the calling thread, before the acceptor
+        * thread is started, so the port is known when this method returns and no
+        * sleep is needed to wait for the acceptor.
+        *
+        * @throws Exception if the listening socket cannot be opened
+        */
+       @Before
+       public void startTestServer() throws Exception {
+               final AtomicInteger activeServers = new AtomicInteger(0);
+               final ServerSocket acceptor = new ServerSocket(0);
+               testServerSock = acceptor;
+               testServerPort = acceptor.getLocalPort();
+               LOG.debug("Listening on {}", testServerPort);
+
+               testServerThread = new Thread(() -> {
+                       try {
+                               while (true) {
+                                       final Socket s = acceptor.accept();
+                                       LOG.debug("Got connection");
+                                       final Thread handler = new Thread(() -> serveClient(s, activeServers));
+                                       handler.setDaemon(true);
+                                       handler.setName("Server " + s.getPort());
+                                       handler.start();
+                               }
+                       } catch (final SocketException e) {
+                               // Closing the listening socket is how the acceptor is stopped
+                               LOG.debug("Shutting down test server");
+                       } catch (final Throwable t) {
+                               LOG.error("Error", t);
+                       }
+               });
+               testServerThread.setDaemon(true);
+               testServerThread.setName("Server Acceptor");
+               testServerThread.start();
+       }
+
+       /**
+        * Stops the test server by closing its listening socket, which makes the
+        * blocked accept() call in the acceptor thread fail.
+        *
+        * @throws Exception if the listening socket cannot be closed
+        */
+       @After
+       public void stopTestServer() throws Exception {
+               try {
+                       if (testServerSock != null) {
+                               testServerSock.close();
+                       }
+               } finally {
+                       testServerSock = null;
+                       if (testServerThread != null) {
+                               testServerThread.interrupt();
+                               testServerThread = null;
+                       }
+               }
+       }
+
+       /**
+        * Creates an SSH client, connects it to the shared SSH server and
+        * authenticates. The client is kept in a field so that
+        * {@link #stopClient()} can stop it.
+        *
+        * @throws IOException if connecting or authenticating fails
+        */
+       @Before
+       public void createClient() throws IOException {
+               client = SshClient.setUpDefaultClient();
+               client.setForwardingFilter(AcceptAllForwardingFilter.INSTANCE);
+               client.start();
+               LOG.debug("Connecting...");
+               session = client.connect("user", TEST_LOCALHOST, sshServerPort).verify(TIMEOUT).getSession();
+               LOG.debug("Authenticating...");
+               session.addPasswordIdentity("foo");
+               session.auth().verify(TIMEOUT);
+               LOG.debug("Authenticated");
+       }
+
+       /**
+        * Closes the session and stops the client. Both are null-checked because
+        * this method also runs when {@link #createClient()} failed part-way.
+        *
+        * @throws Exception if waiting for the session to close fails
+        */
+       @After
+       public void stopClient() throws Exception {
+               LOG.debug("Disconnecting Client");
+               try {
+                       if (session != null) {
+                               assertTrue("Failed to close session", session.close(true).await(TIMEOUT));
+                       }
+               } finally {
+                       session = null;
+                       if (client != null) {
+                               client.stop();
+                               client = null;
+                       }
+               }
+       }
+
+       /**
+        * Run PORT_FORWARD_CLIENT_COUNT simultaneous client threads against the
+        * forwarded port.
+        *
+        * Emulates a web browser making a number of simultaneous requests on
+        * different connections to the same server. HTTP specifies no more than
+        * two, but most modern browsers do 6 or more.
+        *
+        * @throws Exception if the port forward cannot be set up or the test
+        *                   thread is interrupted
+        */
+       @Test
+       public void testConcurrentConnectionsToPortForward() throws Exception {
+               final SshdSocketAddress remote = new SshdSocketAddress(TEST_LOCALHOST, 0);
+               final SshdSocketAddress local = new SshdSocketAddress(TEST_LOCALHOST, testServerPort);
+               final SshdSocketAddress bound = session.startRemotePortForwarding(remote, local);
+               final int forwardedPort = bound.getPort();
+
+               final CyclicBarrier barrier =
+                               new CyclicBarrier(PORT_FORWARD_CLIENT_COUNT, () -> LOG.debug("And away we go."));
+
+               final AtomicInteger success = new AtomicInteger(0);
+               final AtomicInteger fail = new AtomicInteger(0);
+               final long[] bytesRead = new long[PORT_FORWARD_CLIENT_COUNT];
+               final Thread[] clients = new Thread[PORT_FORWARD_CLIENT_COUNT];
+
+               for (int i = 0; i < PORT_FORWARD_CLIENT_COUNT; i++) {
+                       final int n = i;
+                       final Thread t = new Thread(() -> {
+                               try {
+                                       bytesRead[n] = makeClientRequest(forwardedPort, barrier);
+                                       LOG.debug("Complete, received full payload from server.");
+                                       success.incrementAndGet();
+                               } catch (final Throwable e) {
+                                       // Throwable, not Exception: a failed assertion inside
+                                       // makeClientRequest is an AssertionError and must be counted too
+                                       fail.incrementAndGet();
+                                       LOG.error("Error in client code", e);
+                               }
+                       });
+                       t.setName("Client " + i);
+                       t.setDaemon(true);
+                       t.start();
+                       clients[i] = t;
+               }
+
+               // Wait for the clients, but never longer than the overall deadline
+               final long deadline = System.nanoTime() + TimeUnit.MILLISECONDS.toNanos(CLIENTS_COMPLETION_TIMEOUT);
+               for (final Thread t : clients) {
+                       final long remaining = TimeUnit.NANOSECONDS.toMillis(deadline - System.nanoTime());
+                       // join(0) would wait forever, so only join while time is left
+                       if (remaining > 0L) {
+                               t.join(remaining);
+                       }
+               }
+
+               for (int i = 0; i < PORT_FORWARD_CLIENT_COUNT; i++) {
+                       assertFalse("Client " + i + " did not complete within " + CLIENTS_COMPLETION_TIMEOUT + " msec.",
+                                       clients[i].isAlive());
+               }
+
+               for (int i = 0; i < PORT_FORWARD_CLIENT_COUNT; i++) {
+                       assertEquals("Mismatched data length read from server for client " + i,
+                                       PAYLOAD_TO_CLIENT.length, bytesRead[i]);
+               }
+
+               assertEquals("Not all clients succeeded", PORT_FORWARD_CLIENT_COUNT, success.get());
+       }
+
+       /**
+        * Send PAYLOAD_TO_SERVER to the server, then read PAYLOAD_TO_CLIENT from
+        * server. Emulates a web browser making a request.
+        *
+        * @param serverPort port to connect to on TEST_LOCALHOST
+        * @param barrier    barrier all clients wait on before connecting, so that
+        *                   the connections are opened at the same time
+        * @return the number of payload bytes read from the server
+        * @throws Exception if the barrier wait, the connection or the I/O fails
+        */
+       private long makeClientRequest(final int serverPort, final CyclicBarrier barrier) throws Exception {
+               outputDebugMessage("makeClientRequest(port=%d)", serverPort);
+
+               try (Socket s = new Socket()) {
+                       s.setSoTimeout(TIMEOUT);
+
+                       // Bounded wait: if another client never arrives, fail instead of hanging
+                       barrier.await(TIMEOUT, TimeUnit.MILLISECONDS);
+
+                       s.connect(new InetSocketAddress(TEST_LOCALHOST, serverPort), TIMEOUT);
+
+                       s.getOutputStream().write(PAYLOAD_TO_SERVER);
+
+                       final byte[] buf = new byte[PAYLOAD_TO_CLIENT.length];
+                       final long r = readFully(s, buf);
+                       LOG.debug("Read {} payload from server", r);
+
+                       assertEquals("Mismatched data length", PAYLOAD_TO_CLIENT.length, r);
+                       return r;
+               }
+       }
+
+       /**
+        * Handles one connection accepted by the test server: reads the request
+        * payload, writes the reply payload and closes the socket. Errors are
+        * logged and not propagated.
+        *
+        * @param socket        the accepted connection; closed by this method
+        * @param activeServers counter of handlers currently running, used for
+        *                      debug logging only
+        */
+       private static void serveClient(final Socket socket, final AtomicInteger activeServers) {
+               LOG.debug("Active Servers: {}", activeServers.incrementAndGet());
+               try (Socket s = socket) {
+                       final byte[] buf = new byte[PAYLOAD_TO_SERVER.length];
+                       final long r = readFully(s, buf);
+                       LOG.debug("Read {} payload from client", r);
+                       s.getOutputStream().write(PAYLOAD_TO_CLIENT);
+                       LOG.debug("Wrote payload to client");
+               } catch (final Throwable t) {
+                       LOG.error("Error", t);
+               } finally {
+                       LOG.debug("Active Servers: {}", activeServers.decrementAndGet());
+               }
+       }
+
+       /**
+        * Reads from the socket until the buffer is full or the peer closes the
+        * stream. A single read() call may return fewer bytes than requested, so
+        * the read is repeated.
+        *
+        * @param s   socket to read from
+        * @param buf buffer to fill
+        * @return the number of bytes stored in buf; less than buf.length only if
+        *         the stream ended early
+        * @throws IOException if reading fails, including a socket read timeout
+        */
+       private static int readFully(final Socket s, final byte[] buf) throws IOException {
+               int total = 0;
+               while (total < buf.length) {
+                       final int n = s.getInputStream().read(buf, total, buf.length - total);
+                       if (n < 0) {
+                               break;
+                       }
+                       total += n;
+               }
+               return total;
+       }
+
+}
\ No newline at end of file

Reply via email to