GitHub user ivmaykov commented on a diff in the pull request:

    https://github.com/apache/zookeeper/pull/679#discussion_r233657773
  
    --- Diff: 
zookeeper-server/src/test/java/org/apache/zookeeper/server/quorum/UnifiedServerSocketTest.java
 ---
    @@ -17,156 +17,644 @@
      */
     package org.apache.zookeeper.server.quorum;
     
    +import java.io.BufferedInputStream;
    +import java.io.IOException;
    +import java.net.ConnectException;
    +import java.net.InetSocketAddress;
    +import java.net.Socket;
    +import java.net.SocketException;
    +import java.util.ArrayList;
    +import java.util.Collection;
    +import java.util.List;
    +import java.util.Random;
    +
    +import javax.net.ssl.HandshakeCompletedEvent;
    +import javax.net.ssl.HandshakeCompletedListener;
    +import javax.net.ssl.SSLSocket;
    +
     import org.apache.zookeeper.PortAssignment;
     import org.apache.zookeeper.client.ZKClientConfig;
    +import org.apache.zookeeper.common.BaseX509ParameterizedTestCase;
     import org.apache.zookeeper.common.ClientX509Util;
    -import org.apache.zookeeper.common.Time;
    +import org.apache.zookeeper.common.KeyStoreFileType;
    +import org.apache.zookeeper.common.X509Exception;
    +import org.apache.zookeeper.common.X509KeyType;
    +import org.apache.zookeeper.common.X509TestContext;
     import org.apache.zookeeper.common.X509Util;
     import org.apache.zookeeper.server.ServerCnxnFactory;
     import org.junit.Assert;
     import org.junit.Before;
     import org.junit.Test;
    +import org.junit.runner.RunWith;
    +import org.junit.runners.Parameterized;
     
    -import javax.net.ssl.HandshakeCompletedEvent;
    -import javax.net.ssl.HandshakeCompletedListener;
    -import javax.net.ssl.SSLSocket;
    -import java.io.IOException;
    -import java.net.ConnectException;
    -import java.net.InetSocketAddress;
    -import java.net.Socket;
    -
    -import static org.hamcrest.CoreMatchers.equalTo;
    -import static org.junit.Assert.assertThat;
    +@RunWith(Parameterized.class)
    +public class UnifiedServerSocketTest extends BaseX509ParameterizedTestCase 
{
     
    -public class UnifiedServerSocketTest {
    +    @Parameterized.Parameters
    +    public static Collection<Object[]> params() {
    +        ArrayList<Object[]> result = new ArrayList<>();
    +        int paramIndex = 0;
    +        for (X509KeyType caKeyType : X509KeyType.values()) {
    +            for (X509KeyType certKeyType : X509KeyType.values()) {
    +                for (Boolean hostnameVerification : new Boolean[] { true, 
false  }) {
    +                    result.add(new Object[]{
    +                            caKeyType,
    +                            certKeyType,
    +                            hostnameVerification,
    +                            paramIndex++
    +                    });
    +                }
    +            }
    +        }
    +        return result;
    +    }
     
         private static final int MAX_RETRIES = 5;
         private static final int TIMEOUT = 1000;
    +    private static final byte[] DATA_TO_CLIENT = "hello client".getBytes();
    +    private static final byte[] DATA_FROM_CLIENT = "hello 
server".getBytes();
     
         private X509Util x509Util;
         private int port;
    -    private volatile boolean handshakeCompleted;
    +    private InetSocketAddress localServerAddress;
    +    private final Object handshakeCompletedLock = new Object();
    +    // access only inside synchronized(handshakeCompletedLock) { ... } 
blocks
    +    private boolean handshakeCompleted = false;
    +
    +    public UnifiedServerSocketTest(
    +            final X509KeyType caKeyType,
    +            final X509KeyType certKeyType,
    +            final Boolean hostnameVerification,
    +            final Integer paramIndex) {
    +        super(paramIndex, () -> {
    +            try {
    +                return X509TestContext.newBuilder()
    +                    .setTempDir(tempDir)
    +                    .setKeyStoreKeyType(certKeyType)
    +                    .setTrustStoreKeyType(caKeyType)
    +                    .setHostnameVerification(hostnameVerification)
    +                    .build();
    +            } catch (Exception e) {
    +                throw new RuntimeException(e);
    +            }
    +        });
    +    }
     
         @Before
         public void setUp() throws Exception {
    -        handshakeCompleted = false;
    -
             port = PortAssignment.unique();
    +        localServerAddress = new InetSocketAddress("localhost", port);
     
    -        String testDataPath = System.getProperty("test.data.dir", 
"build/test/data");
             
System.setProperty(ServerCnxnFactory.ZOOKEEPER_SERVER_CNXN_FACTORY, 
"org.apache.zookeeper.server.NettyServerCnxnFactory");
             System.setProperty(ZKClientConfig.ZOOKEEPER_CLIENT_CNXN_SOCKET, 
"org.apache.zookeeper.ClientCnxnSocketNetty");
             System.setProperty(ZKClientConfig.SECURE_CLIENT, "true");
     
             x509Util = new ClientX509Util();
     
    -        System.setProperty(x509Util.getSslKeystoreLocationProperty(), 
testDataPath + "/ssl/testKeyStore.jks");
    -        System.setProperty(x509Util.getSslKeystorePasswdProperty(), 
"testpass");
    -        System.setProperty(x509Util.getSslTruststoreLocationProperty(), 
testDataPath + "/ssl/testTrustStore.jks");
    -        System.setProperty(x509Util.getSslTruststorePasswdProperty(), 
"testpass");
    -        
System.setProperty(x509Util.getSslHostnameVerificationEnabledProperty(), 
"false");
    +        x509TestContext.setSystemProperties(x509Util, 
KeyStoreFileType.JKS, KeyStoreFileType.JKS);
         }
     
    -    @Test
    -    public void testConnectWithSSL() throws Exception {
    -        class ServerThread extends Thread {
    -            public void run() {
    -                try {
    -                    Socket unifiedSocket = new 
UnifiedServerSocket(x509Util, port).accept();
    -                    ((SSLSocket)unifiedSocket).getSession(); // block 
until handshake completes
    -                } catch (IOException e) {
    -                    e.printStackTrace();
    +    private static void forceClose(java.io.Closeable s) {
    +        if (s == null) {
    +            return;
    +        }
    +        try {
    +            s.close();
    +        } catch (IOException e) {
    +            e.printStackTrace();
    +        }
    +    }
    +
    +    private static final class UnifiedServerThread extends Thread {
    +        private final byte[] dataToClient;
    +        private List<byte[]> dataFromClients;
    +        private List<Thread> workerThreads;
    +        private UnifiedServerSocket serverSocket;
    +
    +        UnifiedServerThread(X509Util x509Util,
    +                            InetSocketAddress bindAddress,
    +                            boolean allowInsecureConnection,
    +                            byte[] dataToClient) throws IOException {
    +            this.dataToClient = dataToClient;
    +            dataFromClients = new ArrayList<>();
    +            workerThreads = new ArrayList<>();
    +            serverSocket = new UnifiedServerSocket(x509Util, 
allowInsecureConnection);
    +            serverSocket.bind(bindAddress);
    +        }
    +
    +        @Override
    +        public void run() {
    +            try {
    +                Random rnd = new Random();
    +                while (true) {
    +                    final Socket unifiedSocket = serverSocket.accept();
    +                    final boolean tcpNoDelay = rnd.nextBoolean();
    +                    unifiedSocket.setTcpNoDelay(tcpNoDelay);
    +                    unifiedSocket.setSoTimeout(TIMEOUT);
    +                    final boolean keepAlive = rnd.nextBoolean();
    +                    unifiedSocket.setKeepAlive(keepAlive);
    +                    // Note: getting the input stream should not block the 
thread or trigger mode detection.
    +                    BufferedInputStream bis = new 
BufferedInputStream(unifiedSocket.getInputStream());
    +                    Thread t = new Thread(new Runnable() {
    +                        @Override
    +                        public void run() {
    +                            try {
    +                                byte[] buf = new byte[1024];
    +                                int bytesRead = 
unifiedSocket.getInputStream().read(buf, 0, 1024);
    +                                // Make sure the settings applied above 
before the socket was potentially upgraded to
    +                                // TLS still apply.
    +                                Assert.assertEquals(tcpNoDelay, 
unifiedSocket.getTcpNoDelay());
    +                                Assert.assertEquals(TIMEOUT, 
unifiedSocket.getSoTimeout());
    +                                Assert.assertEquals(keepAlive, 
unifiedSocket.getKeepAlive());
    +                                if (bytesRead > 0) {
    +                                    byte[] dataFromClient = new 
byte[bytesRead];
    +                                    System.arraycopy(buf, 0, 
dataFromClient, 0, bytesRead);
    +                                    synchronized (dataFromClients) {
    +                                        
dataFromClients.add(dataFromClient);
    +                                    }
    +                                }
    +                                
unifiedSocket.getOutputStream().write(dataToClient);
    +                                unifiedSocket.getOutputStream().flush();
    +                            } catch (IOException e) {
    +                                e.printStackTrace();
    +                                throw new RuntimeException(e);
    +                            } finally {
    +                                forceClose(unifiedSocket);
    +                            }
    +                        }
    +                    });
    +                    workerThreads.add(t);
    +                    t.start();
                     }
    +            } catch (IOException e) {
    +                e.printStackTrace();
    +                throw new RuntimeException(e);
    +            } finally {
    +                forceClose(serverSocket);
                 }
             }
    -        ServerThread serverThread = new ServerThread();
    -        serverThread.start();
     
    +        public void shutdown(long millis) throws InterruptedException {
    +            forceClose(serverSocket); // this should break the run() loop
    +            for (Thread t : workerThreads) {
    +                t.join(millis);
    +            }
    +            this.join(millis);
    +        }
    +
    +        synchronized byte[] getDataFromClient(int index) {
    +            return dataFromClients.get(index);
    +        }
    +    }
    +
    +    private SSLSocket connectWithSSL() throws IOException, X509Exception, 
InterruptedException {
             SSLSocket sslSocket = null;
             int retries = 0;
             while (retries < MAX_RETRIES) {
                 try {
                     sslSocket = x509Util.createSSLSocket();
    +                sslSocket.addHandshakeCompletedListener(new 
HandshakeCompletedListener() {
    +                    @Override
    +                    public void handshakeCompleted(HandshakeCompletedEvent 
handshakeCompletedEvent) {
    +                        synchronized (handshakeCompletedLock) {
    +                            handshakeCompleted = true;
    +                            handshakeCompletedLock.notifyAll();
    +                        }
    +                    }
    +                });
                     sslSocket.setSoTimeout(TIMEOUT);
    -                sslSocket.connect(new InetSocketAddress(port), TIMEOUT);
    +                sslSocket.connect(localServerAddress, TIMEOUT);
                     break;
                 } catch (ConnectException connectException) {
                     connectException.printStackTrace();
    +                forceClose(sslSocket);
    +                sslSocket = null;
                     Thread.sleep(TIMEOUT);
                 }
                 retries++;
             }
     
    -        sslSocket.addHandshakeCompletedListener(new 
HandshakeCompletedListener() {
    -            @Override
    -            public void handshakeCompleted(HandshakeCompletedEvent 
handshakeCompletedEvent) {
    -                completeHandshake();
    +        Assert.assertNotNull("Failed to connect to server with SSL", 
sslSocket);
    +        return sslSocket;
    +    }
    +
    +    private Socket connectWithoutSSL() throws IOException, 
InterruptedException {
    +        Socket socket = null;
    +        int retries = 0;
    +        while (retries < MAX_RETRIES) {
    +            try {
    +                socket = new Socket();
    +                socket.setSoTimeout(TIMEOUT);
    +                socket.connect(localServerAddress, TIMEOUT);
    +                break;
    +            } catch (ConnectException connectException) {
    +                connectException.printStackTrace();
    +                forceClose(socket);
    +                socket = null;
    +                Thread.sleep(TIMEOUT);
                 }
    -        });
    -        sslSocket.startHandshake();
    +            retries++;
    +        }
    +        Assert.assertNotNull("Failed to connect to server without SSL", 
socket);
    +        return socket;
    +    }
    +
    +    // In the tests below, a "Strict" server means a UnifiedServerSocket 
that
    +    // does not allow plaintext connections (in other words, it's 
SSL-only).
    +    // A "Non Strict" server means a UnifiedServerSocket that allows both
    +    // plaintext and SSL incoming connections.
    +
    +    /**
    +     * Attempting to connect to a SSL-or-plaintext server with SSL should 
work.
    +     */
    +    @Test
    +    public void testConnectWithSSLToNonStrictServer() throws Exception {
    +        UnifiedServerThread serverThread = new UnifiedServerThread(
    +                x509Util, localServerAddress, true, DATA_TO_CLIENT);
    +        serverThread.start();
     
    -        serverThread.join(TIMEOUT);
    +        Socket sslSocket = connectWithSSL();
    +        sslSocket.getOutputStream().write(DATA_FROM_CLIENT);
    +        sslSocket.getOutputStream().flush();
    +        byte[] buf = new byte[DATA_TO_CLIENT.length];
    +        int bytesRead = sslSocket.getInputStream().read(buf, 0, 
buf.length);
    +        Assert.assertEquals(buf.length, bytesRead);
    +        Assert.assertArrayEquals(DATA_TO_CLIENT, buf);
     
    -        long start = Time.currentElapsedTime();
    -        while (Time.currentElapsedTime() < start + TIMEOUT) {
    -            if (handshakeCompleted) {
    -                return;
    +        serverThread.shutdown(TIMEOUT);
    +        forceClose(sslSocket);
    +
    +        synchronized (handshakeCompletedLock) {
    +            if (!handshakeCompleted) {
    +                handshakeCompletedLock.wait(TIMEOUT);
                 }
    +            Assert.assertTrue(handshakeCompleted);
             }
    +        Assert.assertArrayEquals(DATA_FROM_CLIENT, 
serverThread.getDataFromClient(0));
    +    }
    +
    +    /**
    +     * Attempting to connect to a SSL-only server with SSL should work.
    +     */
    +    @Test
    +    public void testConnectWithSSLToStrictServer() throws Exception {
    +        UnifiedServerThread serverThread = new UnifiedServerThread(
    +                x509Util, localServerAddress, false, DATA_TO_CLIENT);
    +        serverThread.start();
    +
    +        Socket sslSocket = connectWithSSL();
    +        sslSocket.getOutputStream().write(DATA_FROM_CLIENT);
    +        sslSocket.getOutputStream().flush();
    +        byte[] buf = new byte[DATA_TO_CLIENT.length];
    +        int bytesRead = sslSocket.getInputStream().read(buf, 0, 
buf.length);
    +        Assert.assertEquals(buf.length, bytesRead);
    +        Assert.assertArrayEquals(DATA_TO_CLIENT, buf);
     
    -        Assert.fail("failed to complete handshake");
    +        serverThread.shutdown(TIMEOUT);
    +        forceClose(sslSocket);
    +
    +        synchronized (handshakeCompletedLock) {
    +            if (!handshakeCompleted) {
    +                handshakeCompletedLock.wait(TIMEOUT);
    +            }
    +            Assert.assertTrue(handshakeCompleted);
    +        }
    +
    +        Assert.assertArrayEquals(DATA_FROM_CLIENT, 
serverThread.getDataFromClient(0));
         }
     
    -    private void completeHandshake() {
    -        handshakeCompleted = true;
    +    /**
    +     * Attempting to connect to a SSL-or-plaintext server without SSL 
should work.
    +     */
    +    @Test
    +    public void testConnectWithoutSSLToNonStrictServer() throws Exception {
    +        UnifiedServerThread serverThread = new UnifiedServerThread(
    +                x509Util, localServerAddress, true, DATA_TO_CLIENT);
    +        serverThread.start();
    +
    +        Socket socket = connectWithoutSSL();
    +        socket.getOutputStream().write(DATA_FROM_CLIENT);
    +        socket.getOutputStream().flush();
    +        byte[] buf = new byte[DATA_TO_CLIENT.length];
    +        int bytesRead = socket.getInputStream().read(buf, 0, buf.length);
    +        Assert.assertEquals(buf.length, bytesRead);
    +        Assert.assertArrayEquals(DATA_TO_CLIENT, buf);
    +
    +        serverThread.shutdown(TIMEOUT);
    +        forceClose(socket);
    +
    +        Assert.assertArrayEquals(DATA_FROM_CLIENT, 
serverThread.getDataFromClient(0));
         }
     
    +    /**
    +     * Attempting to connect to a SSL-or-plaintext server without SSL with 
a
    +     * small initial data write should work. This makes sure that sending
    +     * less than 5 bytes does not break the logic in the server's initial 5
    +     * byte read.
    +     */
         @Test
    -    public void testConnectWithoutSSL() throws Exception {
    -        final byte[] testData = "hello there".getBytes();
    -        final String[] dataReadFromClient = {null};
    +    public void testConnectWithoutSSLToNonStrictServerPartialWrite() 
throws Exception {
    +        UnifiedServerThread serverThread = new UnifiedServerThread(
    +                x509Util, localServerAddress, true, DATA_TO_CLIENT);
    +        serverThread.start();
    +
    +        Socket socket = connectWithoutSSL();
    +        // Write only 2 bytes of the message, wait a bit, then write the 
rest.
    +        // This makes sure that writes smaller than 5 bytes don't break 
the plaintext mode on the server
    +        // once it decides that the input doesn't look like a TLS 
handshake.
    +        socket.getOutputStream().write(DATA_FROM_CLIENT, 0, 2);
    +        socket.getOutputStream().flush();
    +        Thread.sleep(TIMEOUT / 2);
    +        socket.getOutputStream().write(DATA_FROM_CLIENT, 2, 
DATA_FROM_CLIENT.length - 2);
    +        socket.getOutputStream().flush();
    +        byte[] buf = new byte[DATA_TO_CLIENT.length];
    +        int bytesRead = socket.getInputStream().read(buf, 0, buf.length);
    +        Assert.assertEquals(buf.length, bytesRead);
    +        Assert.assertArrayEquals(DATA_TO_CLIENT, buf);
    +
    +        serverThread.shutdown(TIMEOUT);
    +        forceClose(socket);
     
    -        class ServerThread extends Thread {
    +        Assert.assertArrayEquals(DATA_FROM_CLIENT, 
serverThread.getDataFromClient(0));
    +    }
    +
    +    /**
    +     * Attempting to connect to a SSL-only server without SSL should fail.
    +     */
    +    @Test
    +    public void testConnectWithoutSSLToStrictServer() throws Exception {
    +        UnifiedServerThread serverThread = new UnifiedServerThread(
    +                x509Util, localServerAddress, false, DATA_TO_CLIENT);
    +        serverThread.start();
    +
    +        Socket socket = connectWithoutSSL();
    +        socket.getOutputStream().write(DATA_FROM_CLIENT);
    +        socket.getOutputStream().flush();
    +        byte[] buf = new byte[DATA_TO_CLIENT.length];
    +        try {
    +            socket.getInputStream().read(buf, 0, buf.length);
    +        } catch (SocketException e) {
    +            // We expect the other end to hang up the connection
    +            return;
    +        } finally {
    +            serverThread.shutdown(TIMEOUT);
    +            forceClose(socket);
    +        }
    +        Assert.fail("Expected server to hang up the connection. Read from 
server succeeded unexpectedly.");
    +    }
    +
    +    /**
    +     * This test makes sure that UnifiedServerSocket used properly (a 
single thread accept()-ing connections and
    +     * handing the resulting sockets to other threads for processing) is 
not vulnerable to a simple denial-of-service
    +     * attack in which a client connects and never writes any bytes. This 
should not block the accepting thread, since
    +     * the read to determine if the client is sending a TLS handshake or 
not happens in the processing thread.
    +     *
    +     * This version of the test uses a non-strict server socket (i.e. it 
accepts both TLS and plaintext connections).
    +     */
    +    @Test
    +    public void testDenialOfServiceResistanceNonStrictServer() throws 
Exception {
    +        UnifiedServerThread serverThread = new UnifiedServerThread(
    --- End diff --
    
    This is true. I tried to make the threading behavior in this test as
similar as possible to that of `Leader`. I don't think `Leader` itself should be
involved in this unit test, but perhaps a similar "denial-of-service resistance"
test case could be added to one of the quorum tests. Do you know which unit test
would be a good place for such a test case?


---

Reply via email to