Github user ivmaykov commented on a diff in the pull request: https://github.com/apache/zookeeper/pull/679#discussion_r233668370 --- Diff: zookeeper-server/src/test/java/org/apache/zookeeper/server/quorum/UnifiedServerSocketTest.java --- @@ -17,156 +17,644 @@ */ package org.apache.zookeeper.server.quorum; +import java.io.BufferedInputStream; +import java.io.IOException; +import java.net.ConnectException; +import java.net.InetSocketAddress; +import java.net.Socket; +import java.net.SocketException; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import java.util.Random; + +import javax.net.ssl.HandshakeCompletedEvent; +import javax.net.ssl.HandshakeCompletedListener; +import javax.net.ssl.SSLSocket; + import org.apache.zookeeper.PortAssignment; import org.apache.zookeeper.client.ZKClientConfig; +import org.apache.zookeeper.common.BaseX509ParameterizedTestCase; import org.apache.zookeeper.common.ClientX509Util; -import org.apache.zookeeper.common.Time; +import org.apache.zookeeper.common.KeyStoreFileType; +import org.apache.zookeeper.common.X509Exception; +import org.apache.zookeeper.common.X509KeyType; +import org.apache.zookeeper.common.X509TestContext; import org.apache.zookeeper.common.X509Util; import org.apache.zookeeper.server.ServerCnxnFactory; import org.junit.Assert; import org.junit.Before; import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; -import javax.net.ssl.HandshakeCompletedEvent; -import javax.net.ssl.HandshakeCompletedListener; -import javax.net.ssl.SSLSocket; -import java.io.IOException; -import java.net.ConnectException; -import java.net.InetSocketAddress; -import java.net.Socket; - -import static org.hamcrest.CoreMatchers.equalTo; -import static org.junit.Assert.assertThat; +@RunWith(Parameterized.class) +public class UnifiedServerSocketTest extends BaseX509ParameterizedTestCase { -public class UnifiedServerSocketTest { + @Parameterized.Parameters + public static Collection<Object[]> params() { + ArrayList<Object[]> result = new ArrayList<>(); + int paramIndex = 0; + for (X509KeyType caKeyType : X509KeyType.values()) { + for (X509KeyType certKeyType : X509KeyType.values()) { + for (Boolean hostnameVerification : new Boolean[] { true, false }) { + result.add(new Object[]{ + caKeyType, + certKeyType, + hostnameVerification, + paramIndex++ + }); + } + } + } + return result; + } private static final int MAX_RETRIES = 5; private static final int TIMEOUT = 1000; + private static final byte[] DATA_TO_CLIENT = "hello client".getBytes(); + private static final byte[] DATA_FROM_CLIENT = "hello server".getBytes(); private X509Util x509Util; private int port; - private volatile boolean handshakeCompleted; + private InetSocketAddress localServerAddress; + private final Object handshakeCompletedLock = new Object(); + // access only inside synchronized(handshakeCompletedLock) { ... } blocks + private boolean handshakeCompleted = false; + + public UnifiedServerSocketTest( + final X509KeyType caKeyType, + final X509KeyType certKeyType, + final Boolean hostnameVerification, + final Integer paramIndex) { + super(paramIndex, () -> { + try { + return X509TestContext.newBuilder() + .setTempDir(tempDir) + .setKeyStoreKeyType(certKeyType) + .setTrustStoreKeyType(caKeyType) + .setHostnameVerification(hostnameVerification) + .build(); + } catch (Exception e) { + throw new RuntimeException(e); + } + }); + } @Before public void setUp() throws Exception { - handshakeCompleted = false; - port = PortAssignment.unique(); + localServerAddress = new InetSocketAddress("localhost", port); - String testDataPath = System.getProperty("test.data.dir", "build/test/data"); System.setProperty(ServerCnxnFactory.ZOOKEEPER_SERVER_CNXN_FACTORY, "org.apache.zookeeper.server.NettyServerCnxnFactory"); System.setProperty(ZKClientConfig.ZOOKEEPER_CLIENT_CNXN_SOCKET, "org.apache.zookeeper.ClientCnxnSocketNetty"); System.setProperty(ZKClientConfig.SECURE_CLIENT, "true"); x509Util = new ClientX509Util(); - System.setProperty(x509Util.getSslKeystoreLocationProperty(), testDataPath + "/ssl/testKeyStore.jks"); - System.setProperty(x509Util.getSslKeystorePasswdProperty(), "testpass"); - System.setProperty(x509Util.getSslTruststoreLocationProperty(), testDataPath + "/ssl/testTrustStore.jks"); - System.setProperty(x509Util.getSslTruststorePasswdProperty(), "testpass"); - System.setProperty(x509Util.getSslHostnameVerificationEnabledProperty(), "false"); + x509TestContext.setSystemProperties(x509Util, KeyStoreFileType.JKS, KeyStoreFileType.JKS); } - @Test - public void testConnectWithSSL() throws Exception { - class ServerThread extends Thread { - public void run() { - try { - Socket unifiedSocket = new UnifiedServerSocket(x509Util, port).accept(); - ((SSLSocket)unifiedSocket).getSession(); // block until handshake completes - } catch (IOException e) { - e.printStackTrace(); + private static void forceClose(java.io.Closeable s) { + if (s == null) { + return; + } + try { + s.close(); + } catch (IOException e) { + e.printStackTrace(); + } + } + + private static final class UnifiedServerThread extends Thread { + private final byte[] dataToClient; + private List<byte[]> dataFromClients; + private List<Thread> workerThreads; + private UnifiedServerSocket serverSocket; + + UnifiedServerThread(X509Util x509Util, + InetSocketAddress bindAddress, + boolean allowInsecureConnection, + byte[] dataToClient) throws IOException { + this.dataToClient = dataToClient; + dataFromClients = new ArrayList<>(); + workerThreads = new ArrayList<>(); + serverSocket = new UnifiedServerSocket(x509Util, allowInsecureConnection); + serverSocket.bind(bindAddress); + } + + @Override + public void run() { + try { + Random rnd = new Random(); + while (true) { + final Socket unifiedSocket = serverSocket.accept(); + final boolean tcpNoDelay = rnd.nextBoolean(); + unifiedSocket.setTcpNoDelay(tcpNoDelay); + unifiedSocket.setSoTimeout(TIMEOUT); + final boolean keepAlive = rnd.nextBoolean(); + unifiedSocket.setKeepAlive(keepAlive); + // Note: getting the input stream should not block the thread or trigger mode detection. + BufferedInputStream bis = new BufferedInputStream(unifiedSocket.getInputStream()); + Thread t = new Thread(new Runnable() { + @Override + public void run() { + try { + byte[] buf = new byte[1024]; + int bytesRead = unifiedSocket.getInputStream().read(buf, 0, 1024); + // Make sure the settings applied above before the socket was potentially upgraded to + // TLS still apply. + Assert.assertEquals(tcpNoDelay, unifiedSocket.getTcpNoDelay()); + Assert.assertEquals(TIMEOUT, unifiedSocket.getSoTimeout()); + Assert.assertEquals(keepAlive, unifiedSocket.getKeepAlive()); + if (bytesRead > 0) { + byte[] dataFromClient = new byte[bytesRead]; + System.arraycopy(buf, 0, dataFromClient, 0, bytesRead); + synchronized (dataFromClients) { + dataFromClients.add(dataFromClient); + } + } + unifiedSocket.getOutputStream().write(dataToClient); + unifiedSocket.getOutputStream().flush(); + } catch (IOException e) { + e.printStackTrace(); + throw new RuntimeException(e); + } finally { + forceClose(unifiedSocket); + } + } + }); + workerThreads.add(t); + t.start(); } + } catch (IOException e) { + e.printStackTrace(); + throw new RuntimeException(e); + } finally { + forceClose(serverSocket); } } - ServerThread serverThread = new ServerThread(); - serverThread.start(); + public void shutdown(long millis) throws InterruptedException { + forceClose(serverSocket); // this should break the run() loop + for (Thread t : workerThreads) { + t.join(millis); + } + this.join(millis); + } + + synchronized byte[] getDataFromClient(int index) { + return dataFromClients.get(index); + } + } + + private SSLSocket connectWithSSL() throws IOException, X509Exception, InterruptedException { SSLSocket sslSocket = null; int retries = 0; while (retries < MAX_RETRIES) { try { sslSocket = x509Util.createSSLSocket(); + sslSocket.addHandshakeCompletedListener(new HandshakeCompletedListener() { + @Override + public void handshakeCompleted(HandshakeCompletedEvent handshakeCompletedEvent) { + synchronized (handshakeCompletedLock) { + handshakeCompleted = true; + handshakeCompletedLock.notifyAll(); + } + } + }); sslSocket.setSoTimeout(TIMEOUT); - sslSocket.connect(new InetSocketAddress(port), TIMEOUT); + sslSocket.connect(localServerAddress, TIMEOUT); break; } catch (ConnectException connectException) { connectException.printStackTrace(); + forceClose(sslSocket); + sslSocket = null; Thread.sleep(TIMEOUT); } retries++; } - sslSocket.addHandshakeCompletedListener(new HandshakeCompletedListener() { - @Override - public void handshakeCompleted(HandshakeCompletedEvent handshakeCompletedEvent) { - completeHandshake(); + Assert.assertNotNull("Failed to connect to server with SSL", sslSocket); + return sslSocket; + } + + private Socket connectWithoutSSL() throws IOException, InterruptedException { + Socket socket = null; + int retries = 0; + while (retries < MAX_RETRIES) { + try { + socket = new Socket(); + socket.setSoTimeout(TIMEOUT); + socket.connect(localServerAddress, TIMEOUT); + break; + } catch (ConnectException connectException) { + connectException.printStackTrace(); + forceClose(socket); + socket = null; + Thread.sleep(TIMEOUT); } - }); - sslSocket.startHandshake(); + retries++; + } + Assert.assertNotNull("Failed to connect to server without SSL", socket); + return socket; + } + + // In the tests below, a "Strict" server means a UnifiedServerSocket that + // does not allow plaintext connections (in other words, it's SSL-only). + // A "Non Strict" server means a UnifiedServerSocket that allows both + // plaintext and SSL incoming connections. + + /** + * Attempting to connect to a SSL-or-plaintext server with SSL should work. + */ + @Test + public void testConnectWithSSLToNonStrictServer() throws Exception { + UnifiedServerThread serverThread = new UnifiedServerThread( + x509Util, localServerAddress, true, DATA_TO_CLIENT); + serverThread.start(); - serverThread.join(TIMEOUT); + Socket sslSocket = connectWithSSL(); + sslSocket.getOutputStream().write(DATA_FROM_CLIENT); + sslSocket.getOutputStream().flush(); + byte[] buf = new byte[DATA_TO_CLIENT.length]; + int bytesRead = sslSocket.getInputStream().read(buf, 0, buf.length); + Assert.assertEquals(buf.length, bytesRead); + Assert.assertArrayEquals(DATA_TO_CLIENT, buf); - long start = Time.currentElapsedTime(); - while (Time.currentElapsedTime() < start + TIMEOUT) { - if (handshakeCompleted) { - return; + serverThread.shutdown(TIMEOUT); + forceClose(sslSocket); + + synchronized (handshakeCompletedLock) { + if (!handshakeCompleted) { + handshakeCompletedLock.wait(TIMEOUT); } + Assert.assertTrue(handshakeCompleted); } + Assert.assertArrayEquals(DATA_FROM_CLIENT, serverThread.getDataFromClient(0)); + } + + /** + * Attempting to connect to a SSL-only server with SSL should work. + */ + @Test + public void testConnectWithSSLToStrictServer() throws Exception { + UnifiedServerThread serverThread = new UnifiedServerThread( + x509Util, localServerAddress, false, DATA_TO_CLIENT); + serverThread.start(); + + Socket sslSocket = connectWithSSL(); + sslSocket.getOutputStream().write(DATA_FROM_CLIENT); + sslSocket.getOutputStream().flush(); + byte[] buf = new byte[DATA_TO_CLIENT.length]; + int bytesRead = sslSocket.getInputStream().read(buf, 0, buf.length); + Assert.assertEquals(buf.length, bytesRead); + Assert.assertArrayEquals(DATA_TO_CLIENT, buf); - Assert.fail("failed to complete handshake"); + serverThread.shutdown(TIMEOUT); + forceClose(sslSocket); + + synchronized (handshakeCompletedLock) { + if (!handshakeCompleted) { + handshakeCompletedLock.wait(TIMEOUT); + } + Assert.assertTrue(handshakeCompleted); + } + + Assert.assertArrayEquals(DATA_FROM_CLIENT, serverThread.getDataFromClient(0)); } - private void completeHandshake() { - handshakeCompleted = true; + /** + * Attempting to connect to a SSL-or-plaintext server without SSL should work. + */ + @Test + public void testConnectWithoutSSLToNonStrictServer() throws Exception { + UnifiedServerThread serverThread = new UnifiedServerThread( + x509Util, localServerAddress, true, DATA_TO_CLIENT); + serverThread.start(); + + Socket socket = connectWithoutSSL(); + socket.getOutputStream().write(DATA_FROM_CLIENT); + socket.getOutputStream().flush(); + byte[] buf = new byte[DATA_TO_CLIENT.length]; + int bytesRead = socket.getInputStream().read(buf, 0, buf.length); + Assert.assertEquals(buf.length, bytesRead); + Assert.assertArrayEquals(DATA_TO_CLIENT, buf); + + serverThread.shutdown(TIMEOUT); + forceClose(socket); + + Assert.assertArrayEquals(DATA_FROM_CLIENT, serverThread.getDataFromClient(0)); } + /** + * Attempting to connect to a SSL-or-plaintext server without SSL with a + * small initial data write should work. This makes sure that sending + * less than 5 bytes does not break the logic in the server's initial 5 + * byte read. + */ @Test - public void testConnectWithoutSSL() throws Exception { - final byte[] testData = "hello there".getBytes(); - final String[] dataReadFromClient = {null}; + public void testConnectWithoutSSLToNonStrictServerPartialWrite() throws Exception { + UnifiedServerThread serverThread = new UnifiedServerThread( + x509Util, localServerAddress, true, DATA_TO_CLIENT); + serverThread.start(); + + Socket socket = connectWithoutSSL(); + // Write only 2 bytes of the message, wait a bit, then write the rest. + // This makes sure that writes smaller than 5 bytes don't break the plaintext mode on the server + // once it decides that the input doesn't look like a TLS handshake. + socket.getOutputStream().write(DATA_FROM_CLIENT, 0, 2); + socket.getOutputStream().flush(); + Thread.sleep(TIMEOUT / 2); + socket.getOutputStream().write(DATA_FROM_CLIENT, 2, DATA_FROM_CLIENT.length - 2); + socket.getOutputStream().flush(); + byte[] buf = new byte[DATA_TO_CLIENT.length]; + int bytesRead = socket.getInputStream().read(buf, 0, buf.length); + Assert.assertEquals(buf.length, bytesRead); + Assert.assertArrayEquals(DATA_TO_CLIENT, buf); + + serverThread.shutdown(TIMEOUT); + forceClose(socket); - class ServerThread extends Thread { + Assert.assertArrayEquals(DATA_FROM_CLIENT, serverThread.getDataFromClient(0)); + } + + /** + * Attempting to connect to a SSL-only server without SSL should fail. + */ + @Test + public void testConnectWithoutSSLToStrictServer() throws Exception { + UnifiedServerThread serverThread = new UnifiedServerThread( + x509Util, localServerAddress, false, DATA_TO_CLIENT); + serverThread.start(); + + Socket socket = connectWithoutSSL(); + socket.getOutputStream().write(DATA_FROM_CLIENT); + socket.getOutputStream().flush(); + byte[] buf = new byte[DATA_TO_CLIENT.length]; + try { + socket.getInputStream().read(buf, 0, buf.length); + } catch (SocketException e) { + // We expect the other end to hang up the connection + return; + } finally { + serverThread.shutdown(TIMEOUT); + forceClose(socket); + } + Assert.fail("Expected server to hang up the connection. Read from server succeeded unexpectedly."); + } + + /** + * This test makes sure that UnifiedServerSocket used properly (a single thread accept()-ing connections and + * handing the resulting sockets to other threads for processing) is not vulnerable to a simple denial-of-service + * attack in which a client connects and never writes any bytes. This should not block the accepting thread, since + * the read to determine if the client is sending a TLS handshake or not happens in the processing thread. + * + * This version of the test uses a non-strict server socket (i.e. it accepts both TLS and plaintext connections). + */ + @Test + public void testDenialOfServiceResistanceNonStrictServer() throws Exception { + UnifiedServerThread serverThread = new UnifiedServerThread( --- End diff -- How about we keep the unit test that's focused strictly on UnifiedServerSocket in this PR, and I'll open a separate PR/Jira for adding a new unit test that tests `Leader`? There is some value in the existing tests, since they did catch issues with the original UnifiedServerSocket implementation.
---