ZOOKEEPER-3172: Quorum TLS - fix port unification to allow rolling upgrades
Fix numerous problems with UnifiedServerSocket, such as hanging the accept() thread when the client doesn't send any data or crashing if fewer than 5 bytes are read from the socket in the initial read. Re-enable the "portUnification" config option. ## Fixed networking issues/bugs in UnifiedServerSocket - don't crash the `accept()` thread if the client closes the connection without sending any data - don't corrupt the connection if the client sends fewer than 5 bytes for the initial read - delay the detection of TLS vs. plaintext mode until a socket stream is read from or written to. This prevents the `accept()` thread from getting blocked on a `read()` operation from the newly connected socket. - prepending 5 bytes to `PrependableSocket` and then trying to read >5 bytes would only return the first 5 bytes, even if more bytes were available. This is fixed. Author: Ilya Maykov <[email protected]> Reviewers: [email protected] Closes #679 from ivmaykov/ZOOKEEPER-3172 Project: http://git-wip-us.apache.org/repos/asf/zookeeper/repo Commit: http://git-wip-us.apache.org/repos/asf/zookeeper/commit/64104eae Tree: http://git-wip-us.apache.org/repos/asf/zookeeper/tree/64104eae Diff: http://git-wip-us.apache.org/repos/asf/zookeeper/diff/64104eae Branch: refs/heads/master Commit: 64104eaeaa6508f052edfd39c24243a8e26039dc Parents: 91c6cb2 Author: Ilya Maykov <[email protected]> Authored: Tue Nov 27 10:02:24 2018 +0100 Committer: Andor Molnar <[email protected]> Committed: Tue Nov 27 10:02:24 2018 +0100 ---------------------------------------------------------------------- .../org/apache/zookeeper/common/X509Util.java | 55 +- .../org/apache/zookeeper/common/ZKConfig.java | 2 + .../apache/zookeeper/server/quorum/Leader.java | 29 +- .../apache/zookeeper/server/quorum/Learner.java | 9 +- .../server/quorum/PrependableSocket.java | 29 +- .../server/quorum/QuorumCnxManager.java | 34 +- .../zookeeper/server/quorum/QuorumPeer.java | 8 + .../server/quorum/QuorumPeerConfig.java | 5 +- 
.../server/quorum/UnifiedServerSocket.java | 738 ++++++++++++++++++- .../apache/zookeeper/common/X509UtilTest.java | 28 + .../zookeeper/server/quorum/QuorumSSLTest.java | 2 - .../UnifiedServerSocketModeDetectionTest.java | 404 ++++++++++ .../server/quorum/UnifiedServerSocketTest.java | 608 ++++++++++++--- 13 files changed, 1794 insertions(+), 157 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/zookeeper/blob/64104eae/zookeeper-server/src/main/java/org/apache/zookeeper/common/X509Util.java ---------------------------------------------------------------------- diff --git a/zookeeper-server/src/main/java/org/apache/zookeeper/common/X509Util.java b/zookeeper-server/src/main/java/org/apache/zookeeper/common/X509Util.java index 5b97ac6..e3625a5 100644 --- a/zookeeper-server/src/main/java/org/apache/zookeeper/common/X509Util.java +++ b/zookeeper-server/src/main/java/org/apache/zookeeper/common/X509Util.java @@ -18,6 +18,7 @@ package org.apache.zookeeper.common; +import java.io.ByteArrayInputStream; import java.io.IOException; import java.net.Socket; import java.security.GeneralSecurityException; @@ -74,6 +75,8 @@ public abstract class X509Util { "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256" }; + public static final int DEFAULT_HANDSHAKE_DETECTION_TIMEOUT_MILLIS = 5000; + private String sslProtocolProperty = getConfigPrefix() + "protocol"; private String cipherSuitesProperty = getConfigPrefix() + "ciphersuites"; private String sslKeystoreLocationProperty = getConfigPrefix() + "keyStore.location"; @@ -85,6 +88,7 @@ public abstract class X509Util { private String sslHostnameVerificationEnabledProperty = getConfigPrefix() + "hostnameVerification"; private String sslCrlEnabledProperty = getConfigPrefix() + "crl"; private String sslOcspEnabledProperty = getConfigPrefix() + "ocsp"; + private String sslHandshakeDetectionTimeoutMillisProperty = getConfigPrefix() + "handshakeDetectionTimeoutMillis"; private 
String[] cipherSuites; @@ -146,6 +150,16 @@ public abstract class X509Util { return sslOcspEnabledProperty; } + /** + * Returns the config property key that controls the amount of time, in milliseconds, that the first + * UnifiedServerSocket read operation will block for when trying to detect the client mode (TLS or PLAINTEXT). + * + * @return the config property key. + */ + public String getSslHandshakeDetectionTimeoutMillisProperty() { + return sslHandshakeDetectionTimeoutMillisProperty; + } + public SSLContext getDefaultSSLContext() throws X509Exception.SSLContextException { SSLContext result = defaultSSLContext.get(); if (result == null) { @@ -168,6 +182,31 @@ public abstract class X509Util { return createSSLContext(config); } + /** + * Returns the max amount of time, in milliseconds, that the first UnifiedServerSocket read() operation should + * block for when trying to detect the client mode (TLS or PLAINTEXT). + * Defaults to {@link X509Util#DEFAULT_HANDSHAKE_DETECTION_TIMEOUT_MILLIS}. + * + * @return the handshake detection timeout, in milliseconds. + */ + public int getSslHandshakeTimeoutMillis() { + String propertyString = System.getProperty(getSslHandshakeDetectionTimeoutMillisProperty()); + int result; + if (propertyString == null) { + result = DEFAULT_HANDSHAKE_DETECTION_TIMEOUT_MILLIS; + } else { + result = Integer.parseInt(propertyString); + if (result < 1) { + // Timeout of 0 is not allowed, since an infinite timeout can permanently lock up an + // accept() thread. 
+ LOG.warn("Invalid value for " + getSslHandshakeDetectionTimeoutMillisProperty() + ": " + result + + ", using the default value of " + DEFAULT_HANDSHAKE_DETECTION_TIMEOUT_MILLIS); + result = DEFAULT_HANDSHAKE_DETECTION_TIMEOUT_MILLIS; + } + } + return result; + } + public SSLContext createSSLContext(ZKConfig config) throws SSLContextException { KeyManager[] keyManagers = null; TrustManager[] trustManagers = null; @@ -350,14 +389,22 @@ public abstract class X509Util { public SSLSocket createSSLSocket() throws X509Exception, IOException { SSLSocket sslSocket = (SSLSocket) getDefaultSSLContext().getSocketFactory().createSocket(); configureSSLSocket(sslSocket); - + sslSocket.setUseClientMode(true); return sslSocket; } - public SSLSocket createSSLSocket(Socket socket) throws X509Exception, IOException { - SSLSocket sslSocket = (SSLSocket) getDefaultSSLContext().getSocketFactory().createSocket(socket, null, socket.getPort(), true); + public SSLSocket createSSLSocket(Socket socket, byte[] pushbackBytes) throws X509Exception, IOException { + SSLSocket sslSocket; + if (pushbackBytes != null && pushbackBytes.length > 0) { + sslSocket = (SSLSocket) getDefaultSSLContext().getSocketFactory().createSocket( + socket, new ByteArrayInputStream(pushbackBytes), true); + } else { + sslSocket = (SSLSocket) getDefaultSSLContext().getSocketFactory().createSocket( + socket, null, socket.getPort(), true); + } configureSSLSocket(sslSocket); - + sslSocket.setUseClientMode(false); + sslSocket.setNeedClientAuth(true); return sslSocket; } http://git-wip-us.apache.org/repos/asf/zookeeper/blob/64104eae/zookeeper-server/src/main/java/org/apache/zookeeper/common/ZKConfig.java ---------------------------------------------------------------------- diff --git a/zookeeper-server/src/main/java/org/apache/zookeeper/common/ZKConfig.java b/zookeeper-server/src/main/java/org/apache/zookeeper/common/ZKConfig.java index 01bac69..effc0d5 100644 --- 
a/zookeeper-server/src/main/java/org/apache/zookeeper/common/ZKConfig.java +++ b/zookeeper-server/src/main/java/org/apache/zookeeper/common/ZKConfig.java @@ -130,6 +130,8 @@ public class ZKConfig { System.getProperty(x509Util.getSslCrlEnabledProperty())); properties.put(x509Util.getSslOcspEnabledProperty(), System.getProperty(x509Util.getSslOcspEnabledProperty())); + properties.put(x509Util.getSslHandshakeDetectionTimeoutMillisProperty(), + System.getProperty(x509Util.getSslHandshakeDetectionTimeoutMillisProperty())); } /** http://git-wip-us.apache.org/repos/asf/zookeeper/blob/64104eae/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/Leader.java ---------------------------------------------------------------------- diff --git a/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/Leader.java b/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/Leader.java index 9270548..0a892b1 100644 --- a/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/Leader.java +++ b/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/Leader.java @@ -42,7 +42,6 @@ import java.util.concurrent.ConcurrentMap; import javax.security.sasl.SaslException; import org.apache.zookeeper.ZooDefs.OpCode; -import org.apache.zookeeper.common.QuorumX509Util; import org.apache.zookeeper.common.Time; import org.apache.zookeeper.common.X509Exception; import org.apache.zookeeper.server.FinalRequestProcessor; @@ -240,15 +239,15 @@ public class Leader { try { if (self.shouldUsePortUnification()) { if (self.getQuorumListenOnAllIPs()) { - ss = new UnifiedServerSocket(new QuorumX509Util(), self.getQuorumAddress().getPort()); + ss = new UnifiedServerSocket(self.getX509Util(), true, self.getQuorumAddress().getPort()); } else { - ss = new UnifiedServerSocket(new QuorumX509Util()); + ss = new UnifiedServerSocket(self.getX509Util(), true); } } else if (self.isSslQuorum()) { if (self.getQuorumListenOnAllIPs()) { - ss = new 
QuorumX509Util().createSSLServerSocket(self.getQuorumAddress().getPort()); + ss = self.getX509Util().createSSLServerSocket(self.getQuorumAddress().getPort()); } else { - ss = new QuorumX509Util().createSSLServerSocket(); + ss = self.getX509Util().createSSLServerSocket(); } } else { if (self.getQuorumListenOnAllIPs()) { @@ -399,8 +398,10 @@ public class Leader { public void run() { try { while (!stop) { - try{ - Socket s = ss.accept(); + Socket s = null; + boolean error = false; + try { + s = ss.accept(); // start with the initLimit, once the ack is processed // in LearnerHandler switch to the syncLimit @@ -412,6 +413,7 @@ public class Leader { LearnerHandler fh = new LearnerHandler(s, is, Leader.this); fh.start(); } catch (SocketException e) { + error = true; if (stop) { LOG.info("exception while shutting down acceptor: " + e); @@ -425,6 +427,19 @@ public class Leader { } } catch (SaslException e){ LOG.error("Exception while connecting to quorum learner", e); + error = true; + } catch (Exception e) { + error = true; + throw e; + } finally { + // Don't leak sockets on errors + if (error && s != null && !s.isClosed()) { + try { + s.close(); + } catch (IOException e) { + LOG.warn("Error closing socket", e); + } + } } } } catch (Exception e) { http://git-wip-us.apache.org/repos/asf/zookeeper/blob/64104eae/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/Learner.java ---------------------------------------------------------------------- diff --git a/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/Learner.java b/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/Learner.java index c740d53..faaa844 100644 --- a/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/Learner.java +++ b/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/Learner.java @@ -38,9 +38,7 @@ import org.apache.jute.BinaryOutputArchive; import org.apache.jute.InputArchive; import org.apache.jute.OutputArchive; import 
org.apache.jute.Record; -import org.apache.zookeeper.common.QuorumX509Util; import org.apache.zookeeper.common.X509Exception; -import org.apache.zookeeper.common.X509Util; import org.apache.zookeeper.server.ExitCode; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -74,8 +72,6 @@ public class Learner { protected Socket sock; - protected X509Util x509Util; - /** * Socket getter * @return @@ -304,10 +300,7 @@ public class Learner { private Socket createSocket() throws X509Exception, IOException { Socket sock; if (self.isSslQuorum()) { - if (x509Util == null) { - x509Util = new QuorumX509Util(); - } - sock = x509Util.createSSLSocket(); + sock = self.getX509Util().createSSLSocket(); } else { sock = new Socket(); } http://git-wip-us.apache.org/repos/asf/zookeeper/blob/64104eae/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/PrependableSocket.java ---------------------------------------------------------------------- diff --git a/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/PrependableSocket.java b/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/PrependableSocket.java index a86608f..94a526e 100644 --- a/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/PrependableSocket.java +++ b/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/PrependableSocket.java @@ -18,16 +18,15 @@ package org.apache.zookeeper.server.quorum; -import java.io.ByteArrayInputStream; import java.io.IOException; import java.io.InputStream; -import java.io.SequenceInputStream; +import java.io.PushbackInputStream; import java.net.Socket; import java.net.SocketImpl; public class PrependableSocket extends Socket { - private SequenceInputStream sequenceInputStream; + private PushbackInputStream pushbackInputStream; public PrependableSocket(SocketImpl base) throws IOException { super(base); @@ -35,15 +34,31 @@ public class PrependableSocket extends Socket { @Override public InputStream getInputStream() throws 
IOException { - if (sequenceInputStream == null) { + if (pushbackInputStream == null) { return super.getInputStream(); } - return sequenceInputStream; + return pushbackInputStream; } - public void prependToInputStream(byte[] bytes) throws IOException { - sequenceInputStream = new SequenceInputStream(new ByteArrayInputStream(bytes), getInputStream()); + /** + * Prepend some bytes that have already been read back to the socket's input stream. Note that this method can be + * called at most once with a non-0 length per socket instance. + * @param bytes the bytes to prepend. + * @param offset offset in the byte array to start at. + * @param length number of bytes to prepend. + * @throws IOException if this method was already called on the socket instance, or if super.getInputStream() throws. + */ + public void prependToInputStream(byte[] bytes, int offset, int length) throws IOException { + if (length == 0) { + return; // nothing to prepend + } + if (pushbackInputStream != null) { + throw new IOException("prependToInputStream() called more than once"); + } + PushbackInputStream pushbackInputStream = new PushbackInputStream(getInputStream(), length); + pushbackInputStream.unread(bytes, offset, length); + this.pushbackInputStream = pushbackInputStream; } } \ No newline at end of file http://git-wip-us.apache.org/repos/asf/zookeeper/blob/64104eae/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/QuorumCnxManager.java ---------------------------------------------------------------------- diff --git a/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/QuorumCnxManager.java b/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/QuorumCnxManager.java index 8b91023..4175f3c 100644 --- a/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/QuorumCnxManager.java +++ b/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/QuorumCnxManager.java @@ -47,9 +47,7 @@ import java.util.NoSuchElementException; import 
java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicLong; -import org.apache.zookeeper.common.QuorumX509Util; import org.apache.zookeeper.common.X509Exception; -import org.apache.zookeeper.common.X509Util; import org.apache.zookeeper.server.ExitCode; import org.apache.zookeeper.server.quorum.QuorumPeerConfig.ConfigException; import org.apache.zookeeper.server.util.ConfigUtils; @@ -175,9 +173,6 @@ public class QuorumCnxManager { */ private final boolean tcpKeepAlive = Boolean.getBoolean("zookeeper.tcpKeepAlive"); - - private X509Util x509Util; - static public class Message { Message(ByteBuffer buffer, long sid) { this.buffer = buffer; @@ -291,8 +286,6 @@ public class QuorumCnxManager { // Starts listener thread that waits for connection requests listener = new Listener(); listener.setName("QuorumPeerListener"); - - x509Util = new QuorumX509Util(); } private void initializeAuth(final long mySid, @@ -655,17 +648,18 @@ public class QuorumCnxManager { try { LOG.debug("Opening channel to server " + sid); if (self.isSslQuorum()) { - SSLSocket sslSock = x509Util.createSSLSocket(); - setSockOpts(sslSock); - sslSock.connect(electionAddr, cnxTO); - sslSock.startHandshake(); - sock = sslSock; - } else { - sock = new Socket(); - setSockOpts(sock); - sock.connect(electionAddr, cnxTO); - } - LOG.debug("Connected to server " + sid); + SSLSocket sslSock = self.getX509Util().createSSLSocket(); + setSockOpts(sslSock); + sslSock.connect(electionAddr, cnxTO); + sslSock.startHandshake(); + sock = sslSock; + } else { + sock = new Socket(); + setSockOpts(sock); + sock.connect(electionAddr, cnxTO); + + } + LOG.debug("Connected to server " + sid); // Sends connection request asynchronously if the quorum // sasl authentication is enabled. 
This is required because // sasl server authentication process may take few seconds to @@ -876,9 +870,9 @@ public class QuorumCnxManager { while((!shutdown) && (numRetries < 3)){ try { if (self.shouldUsePortUnification()) { - ss = new UnifiedServerSocket(x509Util); + ss = new UnifiedServerSocket(self.getX509Util(), true); } else if (self.isSslQuorum()) { - ss = x509Util.createSSLServerSocket(); + ss = self.getX509Util().createSSLServerSocket(); } else { ss = new ServerSocket(); } http://git-wip-us.apache.org/repos/asf/zookeeper/blob/64104eae/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/QuorumPeer.java ---------------------------------------------------------------------- diff --git a/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/QuorumPeer.java b/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/QuorumPeer.java index 136a538..7abde4b 100644 --- a/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/QuorumPeer.java +++ b/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/QuorumPeer.java @@ -47,6 +47,7 @@ import javax.security.sasl.SaslException; import org.apache.zookeeper.KeeperException.BadArgumentsException; import org.apache.zookeeper.common.AtomicFileWritingIdiom; import org.apache.zookeeper.common.AtomicFileWritingIdiom.WriterStatement; +import org.apache.zookeeper.common.QuorumX509Util; import org.apache.zookeeper.common.Time; import org.apache.zookeeper.common.X509Exception; import org.apache.zookeeper.jmx.MBeanRegistry; @@ -479,6 +480,12 @@ public class QuorumPeer extends ZooKeeperThread implements QuorumStats.Provider return shouldUsePortUnification; } + private final QuorumX509Util x509Util; + + QuorumX509Util getX509Util() { + return x509Util; + } + /** * This is who I think the leader currently is. 
*/ @@ -801,6 +808,7 @@ public class QuorumPeer extends ZooKeeperThread implements QuorumStats.Provider quorumStats = new QuorumStats(this); jmxRemotePeerBean = new HashMap<Long, RemotePeerBean>(); adminServer = AdminServerFactory.createAdminServer(); + x509Util = new QuorumX509Util(); initialize(); } http://git-wip-us.apache.org/repos/asf/zookeeper/blob/64104eae/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/QuorumPeerConfig.java ---------------------------------------------------------------------- diff --git a/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/QuorumPeerConfig.java b/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/QuorumPeerConfig.java index 45463b1..aee5efc 100644 --- a/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/QuorumPeerConfig.java +++ b/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/QuorumPeerConfig.java @@ -315,9 +315,8 @@ public class QuorumPeerConfig { } } else if (key.equals("sslQuorum")){ sslQuorum = Boolean.parseBoolean(value); -// TODO: UnifiedServerSocket is currently buggy, will be fixed when @ivmaykov's PRs are merged. Disable port unification until then. 
-// } else if (key.equals("portUnification")){ -// shouldUsePortUnification = Boolean.parseBoolean(value); + } else if (key.equals("portUnification")){ + shouldUsePortUnification = Boolean.parseBoolean(value); } else if ((key.startsWith("server.") || key.startsWith("group") || key.startsWith("weight")) && zkProp.containsKey("dynamicConfigFile")) { throw new ConfigException("parameter: " + key + " must be in a separate dynamic config file"); } else if (key.equals(QuorumAuth.QUORUM_SASL_AUTH_ENABLED)) { http://git-wip-us.apache.org/repos/asf/zookeeper/blob/64104eae/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/UnifiedServerSocket.java ---------------------------------------------------------------------- diff --git a/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/UnifiedServerSocket.java b/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/UnifiedServerSocket.java index d1e3ba5..bbe245f 100644 --- a/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/UnifiedServerSocket.java +++ b/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/UnifiedServerSocket.java @@ -27,23 +27,111 @@ import org.slf4j.LoggerFactory; import javax.net.ssl.SSLSocket; import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.net.InetAddress; import java.net.ServerSocket; import java.net.Socket; +import java.net.SocketAddress; import java.net.SocketException; +import java.net.SocketTimeoutException; +import java.nio.channels.SocketChannel; +/** + * A ServerSocket that can act either as a regular ServerSocket, as a SSLServerSocket, or as both, depending on + * the constructor parameters and on the type of client (TLS or plaintext) that connects to it. 
+ * The constructors have the same signature as constructors of ServerSocket, with the addition of two parameters + * at the beginning: + * <ul> + * <li>X509Util - provides the SSL context to construct a secure socket when a client connects with TLS.</li> + * <li>boolean allowInsecureConnection - when true, acts as a hybrid server socket (plaintext / TLS). When + * false, acts as a SSLServerSocket (rejects plaintext connections).</li> + * </ul> + * The <code>!allowInsecureConnection</code> mode is needed so we can update the SSLContext (in particular, the + * key store and/or trust store) without having to re-create the server socket. By starting with a plaintext socket + * and delaying the upgrade to TLS until after a client has connected and begins a handshake, we can keep the same + * UnifiedServerSocket instance around, and replace the default SSLContext in the provided X509Util when the key store + * and/or trust store file changes on disk. + */ public class UnifiedServerSocket extends ServerSocket { private static final Logger LOG = LoggerFactory.getLogger(UnifiedServerSocket.class); private X509Util x509Util; + private final boolean allowInsecureConnection; - public UnifiedServerSocket(X509Util x509Util) throws IOException { + /** + * Creates an unbound unified server socket by calling {@link ServerSocket#ServerSocket()}. + * Secure client connections will be upgraded to TLS once this socket detects the ClientHello message (start of a + * TLS handshake). Plaintext client connections will either be accepted or rejected depending on the value of + * the <code>allowInsecureConnection</code> parameter. + * @param x509Util the X509Util that provides the SSLContext to use for secure connections. + * @param allowInsecureConnection if true, accept plaintext connections, otherwise close them. + * @throws IOException if {@link ServerSocket#ServerSocket()} throws. 
+ */ + public UnifiedServerSocket(X509Util x509Util, boolean allowInsecureConnection) throws IOException { super(); this.x509Util = x509Util; + this.allowInsecureConnection = allowInsecureConnection; } - public UnifiedServerSocket(X509Util x509Util, int port) throws IOException { + /** + * Creates a unified server socket bound to the specified port by calling {@link ServerSocket#ServerSocket(int)}. + * Secure client connections will be upgraded to TLS once this socket detects the ClientHello message (start of a + * TLS handshake). Plaintext client connections will either be accepted or rejected depending on the value of + * the <code>allowInsecureConnection</code> parameter. + * @param x509Util the X509Util that provides the SSLContext to use for secure connections. + * @param allowInsecureConnection if true, accept plaintext connections, otherwise close them. + * @param port the port number, or {@code 0} to use a port number that is automatically allocated. + * @throws IOException if {@link ServerSocket#ServerSocket(int)} throws. + */ + public UnifiedServerSocket(X509Util x509Util, boolean allowInsecureConnection, int port) throws IOException { super(port); this.x509Util = x509Util; + this.allowInsecureConnection = allowInsecureConnection; + } + + /** + * Creates a unified server socket bound to the specified port, with the specified backlog, by calling + * {@link ServerSocket#ServerSocket(int, int)}. + * Secure client connections will be upgraded to TLS once this socket detects the ClientHello message (start of a + * TLS handshake). Plaintext client connections will either be accepted or rejected depending on the value of + * the <code>allowInsecureConnection</code> parameter. + * @param x509Util the X509Util that provides the SSLContext to use for secure connections. + * @param allowInsecureConnection if true, accept plaintext connections, otherwise close them. + * @param port the port number, or {@code 0} to use a port number that is automatically allocated. 
+ * @param backlog requested maximum length of the queue of incoming connections. + * @throws IOException if {@link ServerSocket#ServerSocket(int, int)} throws. + */ + public UnifiedServerSocket(X509Util x509Util, + boolean allowInsecureConnection, + int port, + int backlog) throws IOException { + super(port, backlog); + this.x509Util = x509Util; + this.allowInsecureConnection = allowInsecureConnection; + } + + /** + * Creates a unified server socket bound to the specified port, with the specified backlog, and local IP address + * to bind to, by calling {@link ServerSocket#ServerSocket(int, int, InetAddress)}. + * Secure client connections will be upgraded to TLS once this socket detects the ClientHello message (start of a + * TLS handshake). Plaintext client connections will either be accepted or rejected depending on the value of + * the <code>allowInsecureConnection</code> parameter. + * @param x509Util the X509Util that provides the SSLContext to use for secure connections. + * @param allowInsecureConnection if true, accept plaintext connections, otherwise close them. + * @param port the port number, or {@code 0} to use a port number that is automatically allocated. + * @param backlog requested maximum length of the queue of incoming connections. + * @param bindAddr the local InetAddress the server will bind to. + * @throws IOException if {@link ServerSocket#ServerSocket(int, int, InetAddress)} throws. 
+ */ + public UnifiedServerSocket(X509Util x509Util, + boolean allowInsecureConnection, + int port, + int backlog, + InetAddress bindAddr) throws IOException { + super(port, backlog, bindAddr); + this.x509Util = x509Util; + this.allowInsecureConnection = allowInsecureConnection; } @Override @@ -56,24 +144,642 @@ public class UnifiedServerSocket extends ServerSocket { } final PrependableSocket prependableSocket = new PrependableSocket(null); implAccept(prependableSocket); + return new UnifiedSocket(x509Util, allowInsecureConnection, prependableSocket); + } + + /** + * The result of calling accept() on a UnifiedServerSocket. This is a Socket that doesn't know if it's + * using plaintext or SSL/TLS at the time when it is created. Calling a method that indicates a desire to + * read from or write to the socket will cause the socket to detect if the connected client is attempting + * to establish a TLS or plaintext connection. This is done by doing a blocking read of 5 bytes off the + * socket and checking if the bytes look like the start of a TLS ClientHello message. If it looks like + * the client is attempting to connect with TLS, the internal socket is upgraded to an SSLSocket. If not, + * any bytes read from the socket are pushed back to the input stream, and the socket continues + * to be treated as a plaintext socket (unless insecure connections are disallowed, in which case it is closed). + * + * The methods that trigger this behavior are: + * <ul> + * <li>{@link UnifiedSocket#getInputStream()}</li> + * <li>{@link UnifiedSocket#getOutputStream()}</li> + * <li>{@link UnifiedSocket#sendUrgentData(int)}</li> + * </ul> + * + * Calling other socket methods (e.g. option setters such as {@link Socket#setTcpNoDelay(boolean)}) does + * not trigger mode detection. + * + * Because detecting the mode is a potentially blocking operation, it should not be done in the + * accepting thread.
Attempting to read from or write to the socket in the accepting thread opens the + * caller up to a denial-of-service attack, in which a client connects and then does nothing. This would + * prevent any other clients from connecting. Passing the socket returned by accept() to a separate + * thread which handles all read and write operations protects against this DoS attack. + * + * Callers can check if the socket has been upgraded to TLS by calling {@link UnifiedSocket#isSecureSocket()}, + * and can get the underlying SSLSocket by calling {@link UnifiedSocket#getSslSocket()}. + */ + public static class UnifiedSocket extends Socket { + private enum Mode { + UNKNOWN, + PLAINTEXT, + TLS + } - byte[] litmus = new byte[5]; - int bytesRead = prependableSocket.getInputStream().read(litmus, 0, 5); - prependableSocket.prependToInputStream(litmus); + private final X509Util x509Util; + private final boolean allowInsecureConnection; + private PrependableSocket prependableSocket; + private SSLSocket sslSocket; + private Mode mode; - if (bytesRead == 5 && SslHandler.isEncrypted(Unpooled.wrappedBuffer(litmus))) { - LOG.info(getInetAddress() + " attempting to connect over ssl"); - SSLSocket sslSocket; + /** + * Note: this constructor is intentionally private. The only intended caller is + * {@link UnifiedServerSocket#accept()}. + * + * @param x509Util + * @param allowInsecureConnection + * @param prependableSocket + */ + private UnifiedSocket(X509Util x509Util, boolean allowInsecureConnection, PrependableSocket prependableSocket) { + this.x509Util = x509Util; + this.allowInsecureConnection = allowInsecureConnection; + this.prependableSocket = prependableSocket; + this.sslSocket = null; + this.mode = Mode.UNKNOWN; + } + + /** + * Returns true if the socket mode has been determined to be TLS. + * @return true if the mode is TLS, false if it is UNKNOWN or PLAINTEXT. 
+ */ + public boolean isSecureSocket() { + return mode == Mode.TLS; + } + + /** + * Returns true if the socket mode has been determined to be PLAINTEXT. + * @return true if the mode is PLAINTEXT, false if it is UNKNOWN or TLS. + */ + public boolean isPlaintextSocket() { + return mode == Mode.PLAINTEXT; + } + + /** + * Returns true if the socket mode has already been determined. + * @return true if the mode is PLAINTEXT or TLS, false if it is UNKNOWN. + */ + public boolean isModeKnown() { + return mode != Mode.UNKNOWN; + } + + /** + * Detects the socket mode, see comments at the top of the class for more details. This operation will block + * for up to {@link X509Util#getSslHandshakeTimeoutMillis()} milliseconds and should not be called in the + * accept() thread if possible. + * @throws IOException + */ + private void detectMode() throws IOException { + byte[] litmus = new byte[5]; + int oldTimeout = -1; + int bytesRead = 0; + int newTimeout = x509Util.getSslHandshakeTimeoutMillis(); try { - sslSocket = x509Util.createSSLSocket(prependableSocket); - } catch (X509Exception e) { - throw new IOException("failed to create SSL context", e); + oldTimeout = prependableSocket.getSoTimeout(); + prependableSocket.setSoTimeout(newTimeout); + bytesRead = prependableSocket.getInputStream().read(litmus, 0, litmus.length); + } catch (SocketTimeoutException e) { + // Didn't read anything within the timeout, fallthrough and assume the connection is plaintext. + LOG.warn("Socket mode detection timed out after " + newTimeout + " ms, assuming PLAINTEXT"); + } finally { + // restore socket timeout to the old value + try { + if (oldTimeout != -1) { + prependableSocket.setSoTimeout(oldTimeout); + } + } catch (Exception e) { + LOG.warn("Failed to restore old socket timeout value of " + oldTimeout + " ms", e); + } + } + if (bytesRead < 0) { // Got an EOF right away, definitely not using TLS. Fallthrough.
+ bytesRead = 0; + } + + if (bytesRead == litmus.length && SslHandler.isEncrypted(Unpooled.wrappedBuffer(litmus))) { + try { + sslSocket = x509Util.createSSLSocket(prependableSocket, litmus); + } catch (X509Exception e) { + throw new IOException("failed to create SSL context", e); + } + prependableSocket = null; + mode = Mode.TLS; + } else if (allowInsecureConnection) { + prependableSocket.prependToInputStream(litmus, 0, bytesRead); + mode = Mode.PLAINTEXT; + } else { + prependableSocket.close(); + mode = Mode.PLAINTEXT; + throw new IOException("Blocked insecure connection attempt"); + } + } + + private Socket getSocketAllowUnknownMode() { + if (isSecureSocket()) { + return sslSocket; + } else { // Note: mode is UNKNOWN or PLAINTEXT + return prependableSocket; + } + } + + /** + * Returns the underlying socket, detecting the socket mode if it is not yet known. This is a potentially + * blocking operation and should not be called in the accept() thread. + * @return the underlying socket, after the socket mode has been determined. + * @throws IOException + */ + private Socket getSocket() throws IOException { + if (!isModeKnown()) { + detectMode(); + } + if (mode == Mode.TLS) { + return sslSocket; + } else { + return prependableSocket; + } + } + + /** + * Returns the underlying SSLSocket if the mode is TLS. If the mode is UNKNOWN, causes mode detection which is a + * potentially blocking operation. If the mode ends up being PLAINTEXT, this will throw a SocketException, so + * callers are advised to only call this method after checking that {@link UnifiedSocket#isSecureSocket()} + * returned true. + * @return the underlying SSLSocket if the mode is known to be TLS. + * @throws IOException if detecting the socket mode fails + * @throws SocketException if the mode is PLAINTEXT. 
+ */ + public SSLSocket getSslSocket() throws IOException { + if (!isModeKnown()) { + detectMode(); + } + if (!isSecureSocket()) { + throw new SocketException("Socket mode is not TLS"); } - sslSocket.setUseClientMode(false); return sslSocket; - } else { - LOG.info(getInetAddress() + " attempting to connect without ssl"); - return prependableSocket; } + + /** + * See {@link Socket#connect(SocketAddress)}. Calling this method does not trigger mode detection. + */ + @Override + public void connect(SocketAddress endpoint) throws IOException { + getSocketAllowUnknownMode().connect(endpoint); + } + + /** + * See {@link Socket#connect(SocketAddress, int)}. Calling this method does not trigger mode detection. + */ + @Override + public void connect(SocketAddress endpoint, int timeout) throws IOException { + getSocketAllowUnknownMode().connect(endpoint, timeout); + } + + /** + * See {@link Socket#bind(SocketAddress)}. Calling this method does not trigger mode detection. + */ + @Override + public void bind(SocketAddress bindpoint) throws IOException { + getSocketAllowUnknownMode().bind(bindpoint); + } + + /** + * See {@link Socket#getInetAddress()}. Calling this method does not trigger mode detection. + */ + @Override + public InetAddress getInetAddress() { + return getSocketAllowUnknownMode().getInetAddress(); + } + + /** + * See {@link Socket#getLocalAddress()}. Calling this method does not trigger mode detection. + */ + @Override + public InetAddress getLocalAddress() { + return getSocketAllowUnknownMode().getLocalAddress(); + } + + /** + * See {@link Socket#getPort()}. Calling this method does not trigger mode detection. + */ + @Override + public int getPort() { + return getSocketAllowUnknownMode().getPort(); + } + + /** + * See {@link Socket#getLocalPort()}. Calling this method does not trigger mode detection. + */ + @Override + public int getLocalPort() { + return getSocketAllowUnknownMode().getLocalPort(); + } + + /** + * See {@link Socket#getRemoteSocketAddress()}. 
Calling this method does not trigger mode detection. + */ + @Override + public SocketAddress getRemoteSocketAddress() { + return getSocketAllowUnknownMode().getRemoteSocketAddress(); + } + + /** + * See {@link Socket#getLocalSocketAddress()}. Calling this method does not trigger mode detection. + */ + @Override + public SocketAddress getLocalSocketAddress() { + return getSocketAllowUnknownMode().getLocalSocketAddress(); + } + + /** + * See {@link Socket#getChannel()}. Calling this method does not trigger mode detection. + */ + @Override + public SocketChannel getChannel() { + return getSocketAllowUnknownMode().getChannel(); + } + + /** + * See {@link Socket#getInputStream()}. If the socket mode has not yet been detected, the first read from the + * returned input stream will trigger mode detection, which is a potentially blocking operation. This means + * the accept() thread should avoid reading from this input stream if possible. + */ + @Override + public InputStream getInputStream() throws IOException { + return new UnifiedInputStream(this); + } + + /** + * See {@link Socket#getOutputStream()}. If the socket mode has not yet been detected, the first read from the + * returned input stream will trigger mode detection, which is a potentially blocking operation. This means + * the accept() thread should avoid reading from this input stream if possible. + */ + @Override + public OutputStream getOutputStream() throws IOException { + return new UnifiedOutputStream(this); + } + + /** + * See {@link Socket#setTcpNoDelay(boolean)}. Calling this method does not trigger mode detection. + */ + @Override + public void setTcpNoDelay(boolean on) throws SocketException { + getSocketAllowUnknownMode().setTcpNoDelay(on); + } + + /** + * See {@link Socket#getTcpNoDelay()}. Calling this method does not trigger mode detection. 
+ */ + @Override + public boolean getTcpNoDelay() throws SocketException { + return getSocketAllowUnknownMode().getTcpNoDelay(); + } + + /** + * See {@link Socket#setSoLinger(boolean, int)}. Calling this method does not trigger mode detection. + */ + @Override + public void setSoLinger(boolean on, int linger) throws SocketException { + getSocketAllowUnknownMode().setSoLinger(on, linger); + } + + /** + * See {@link Socket#getSoLinger()}. Calling this method does not trigger mode detection. + */ + @Override + public int getSoLinger() throws SocketException { + return getSocketAllowUnknownMode().getSoLinger(); + } + + /** + * See {@link Socket#sendUrgentData(int)}. Calling this method triggers mode detection, which is a potentially + * blocking operation, so it should not be done in the accept() thread. + */ + @Override + public void sendUrgentData(int data) throws IOException { + getSocket().sendUrgentData(data); + } + + /** + * See {@link Socket#setOOBInline(boolean)}. Calling this method does not trigger mode detection. + */ + @Override + public void setOOBInline(boolean on) throws SocketException { + getSocketAllowUnknownMode().setOOBInline(on); + } + + /** + * See {@link Socket#getOOBInline()}. Calling this method does not trigger mode detection. + */ + @Override + public boolean getOOBInline() throws SocketException { + return getSocketAllowUnknownMode().getOOBInline(); + } + + /** + * See {@link Socket#setSoTimeout(int)}. Calling this method does not trigger mode detection. + */ + @Override + public synchronized void setSoTimeout(int timeout) throws SocketException { + getSocketAllowUnknownMode().setSoTimeout(timeout); + } + + /** + * See {@link Socket#getSoTimeout()}. Calling this method does not trigger mode detection. + */ + @Override + public synchronized int getSoTimeout() throws SocketException { + return getSocketAllowUnknownMode().getSoTimeout(); + } + + /** + * See {@link Socket#setSendBufferSize(int)}. 
Calling this method does not trigger mode detection. + */ + @Override + public synchronized void setSendBufferSize(int size) throws SocketException { + getSocketAllowUnknownMode().setSendBufferSize(size); + } + + /** + * See {@link Socket#getSendBufferSize()}. Calling this method does not trigger mode detection. + */ + @Override + public synchronized int getSendBufferSize() throws SocketException { + return getSocketAllowUnknownMode().getSendBufferSize(); + } + + /** + * See {@link Socket#setReceiveBufferSize(int)}. Calling this method does not trigger mode detection. + */ + @Override + public synchronized void setReceiveBufferSize(int size) throws SocketException { + getSocketAllowUnknownMode().setReceiveBufferSize(size); + } + + /** + * See {@link Socket#getReceiveBufferSize()}. Calling this method does not trigger mode detection. + */ + @Override + public synchronized int getReceiveBufferSize() throws SocketException { + return getSocketAllowUnknownMode().getReceiveBufferSize(); + } + + /** + * See {@link Socket#setKeepAlive(boolean)}. Calling this method does not trigger mode detection. + */ + @Override + public void setKeepAlive(boolean on) throws SocketException { + getSocketAllowUnknownMode().setKeepAlive(on); + } + + /** + * See {@link Socket#getKeepAlive()}. Calling this method does not trigger mode detection. + */ + @Override + public boolean getKeepAlive() throws SocketException { + return getSocketAllowUnknownMode().getKeepAlive(); + } + + /** + * See {@link Socket#setTrafficClass(int)}. Calling this method does not trigger mode detection. + */ + @Override + public void setTrafficClass(int tc) throws SocketException { + getSocketAllowUnknownMode().setTrafficClass(tc); + } + + /** + * See {@link Socket#getTrafficClass()}. Calling this method does not trigger mode detection. 
+ */ + @Override + public int getTrafficClass() throws SocketException { + return getSocketAllowUnknownMode().getTrafficClass(); + } + + /** + * See {@link Socket#setReuseAddress(boolean)}. Calling this method does not trigger mode detection. + */ + @Override + public void setReuseAddress(boolean on) throws SocketException { + getSocketAllowUnknownMode().setReuseAddress(on); + } + + /** + * See {@link Socket#getReuseAddress()}. Calling this method does not trigger mode detection. + */ + @Override + public boolean getReuseAddress() throws SocketException { + return getSocketAllowUnknownMode().getReuseAddress(); + } + + /** + * See {@link Socket#close()}. Calling this method does not trigger mode detection. + */ + @Override + public synchronized void close() throws IOException { + getSocketAllowUnknownMode().close(); + } + + /** + * See {@link Socket#shutdownInput()}. Calling this method does not trigger mode detection. + */ + @Override + public void shutdownInput() throws IOException { + getSocketAllowUnknownMode().shutdownInput(); + } + + /** + * See {@link Socket#shutdownOutput()}. Calling this method does not trigger mode detection. + */ + @Override + public void shutdownOutput() throws IOException { + getSocketAllowUnknownMode().shutdownOutput(); + } + + /** + * See {@link Socket#toString()}. Calling this method does not trigger mode detection. + */ + @Override + public String toString() { + return "UnifiedSocket[mode=" + mode.toString() + "socket=" + getSocketAllowUnknownMode().toString() + "]"; + } + + /** + * See {@link Socket#isConnected()}. Calling this method does not trigger mode detection. + */ + @Override + public boolean isConnected() { + return getSocketAllowUnknownMode().isConnected(); + } + + /** + * See {@link Socket#isBound()}. Calling this method does not trigger mode detection. + */ + @Override + public boolean isBound() { + return getSocketAllowUnknownMode().isBound(); + } + + /** + * See {@link Socket#isClosed()}. 
Calling this method does not trigger mode detection. + */ + @Override + public boolean isClosed() { + return getSocketAllowUnknownMode().isClosed(); + } + + /** + * See {@link Socket#isInputShutdown()}. Calling this method does not trigger mode detection. + */ + @Override + public boolean isInputShutdown() { + return getSocketAllowUnknownMode().isInputShutdown(); + } + + /** + * See {@link Socket#isOutputShutdown()}. Calling this method does not trigger mode detection. + */ + @Override + public boolean isOutputShutdown() { + return getSocketAllowUnknownMode().isOutputShutdown(); + } + + /** + * See {@link Socket#setPerformancePreferences(int, int, int)}. Calling this method does not trigger + * mode detection. + */ + @Override + public void setPerformancePreferences(int connectionTime, int latency, int bandwidth) { + getSocketAllowUnknownMode().setPerformancePreferences(connectionTime, latency, bandwidth); + } + } + + /** + * An input stream for a UnifiedSocket. The first read from this stream will trigger mode detection on the + * underlying UnifiedSocket. + */ + private static class UnifiedInputStream extends InputStream { + private final UnifiedSocket unifiedSocket; + private InputStream realInputStream; + + private UnifiedInputStream(UnifiedSocket unifiedSocket) { + this.unifiedSocket = unifiedSocket; + this.realInputStream = null; + } + + @Override + public int read() throws IOException { + return getRealInputStream().read(); + } + + /** + * Note: SocketInputStream has optimized implementations of bulk-read operations, so we need to call them + * directly instead of relying on the base-class implementation which just calls the single-byte read() over + * and over. Not implementing these results in awful performance. 
+ */ + @Override + public int read(byte[] b) throws IOException { + return getRealInputStream().read(b); + } + + @Override + public int read(byte[] b, int off, int len) throws IOException { + return getRealInputStream().read(b, off, len); + } + + private InputStream getRealInputStream() throws IOException { + if (realInputStream == null) { + // Note: The first call to getSocket() triggers mode detection which can block + realInputStream = unifiedSocket.getSocket().getInputStream(); + } + return realInputStream; + } + + @Override + public long skip(long n) throws IOException { + return getRealInputStream().skip(n); + } + + @Override + public int available() throws IOException { + return getRealInputStream().available(); + } + + @Override + public void close() throws IOException { + getRealInputStream().close(); + } + + @Override + public synchronized void mark(int readlimit) { + try { + getRealInputStream().mark(readlimit); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + @Override + public synchronized void reset() throws IOException { + getRealInputStream().reset(); + } + + @Override + public boolean markSupported() { + try { + return getRealInputStream().markSupported(); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + } + + private static class UnifiedOutputStream extends OutputStream { + private final UnifiedSocket unifiedSocket; + private OutputStream realOutputStream; + + private UnifiedOutputStream(UnifiedSocket unifiedSocket) { + this.unifiedSocket = unifiedSocket; + this.realOutputStream = null; + } + + @Override + public void write(int b) throws IOException { + getRealOutputStream().write(b); + } + + @Override + public void write(byte[] b) throws IOException { + getRealOutputStream().write(b); + } + + @Override + public void write(byte[] b, int off, int len) throws IOException { + getRealOutputStream().write(b, off, len); + } + + @Override + public void flush() throws IOException { + 
getRealOutputStream().flush(); + } + + @Override + public void close() throws IOException { + getRealOutputStream().close(); + } + + private OutputStream getRealOutputStream() throws IOException { + if (realOutputStream == null) { + // Note: The first call to getSocket() triggers mode detection which can block + realOutputStream = unifiedSocket.getSocket().getOutputStream(); + } + return realOutputStream; + } + } -} \ No newline at end of file +} http://git-wip-us.apache.org/repos/asf/zookeeper/blob/64104eae/zookeeper-server/src/test/java/org/apache/zookeeper/common/X509UtilTest.java ---------------------------------------------------------------------- diff --git a/zookeeper-server/src/test/java/org/apache/zookeeper/common/X509UtilTest.java b/zookeeper-server/src/test/java/org/apache/zookeeper/common/X509UtilTest.java index 6b343c3..546cf55 100644 --- a/zookeeper-server/src/test/java/org/apache/zookeeper/common/X509UtilTest.java +++ b/zookeeper-server/src/test/java/org/apache/zookeeper/common/X509UtilTest.java @@ -356,6 +356,34 @@ public class X509UtilTest extends BaseX509ParameterizedTestCase { true); } + @Test + public void testGetSslHandshakeDetectionTimeoutMillisProperty() { + X509Util x509Util = new ClientX509Util(); + Assert.assertEquals( + X509Util.DEFAULT_HANDSHAKE_DETECTION_TIMEOUT_MILLIS, + x509Util.getSslHandshakeTimeoutMillis()); + try { + String newPropertyString = Integer.toString(X509Util.DEFAULT_HANDSHAKE_DETECTION_TIMEOUT_MILLIS + 1); + System.setProperty(x509Util.getSslHandshakeDetectionTimeoutMillisProperty(), newPropertyString); + // Note: need to create a new ClientX509Util to pick up modified property value + Assert.assertEquals( + X509Util.DEFAULT_HANDSHAKE_DETECTION_TIMEOUT_MILLIS + 1, + new ClientX509Util().getSslHandshakeTimeoutMillis()); + // 0 value not allowed, will return the default + System.setProperty(x509Util.getSslHandshakeDetectionTimeoutMillisProperty(), "0"); + Assert.assertEquals( + 
X509Util.DEFAULT_HANDSHAKE_DETECTION_TIMEOUT_MILLIS, + new ClientX509Util().getSslHandshakeTimeoutMillis()); + // Negative value not allowed, will return the default + System.setProperty(x509Util.getSslHandshakeDetectionTimeoutMillisProperty(), "-1"); + Assert.assertEquals( + X509Util.DEFAULT_HANDSHAKE_DETECTION_TIMEOUT_MILLIS, + new ClientX509Util().getSslHandshakeTimeoutMillis()); + } finally { + System.clearProperty(x509Util.getSslHandshakeDetectionTimeoutMillisProperty()); + } + } + // Warning: this will reset the x509Util private void setCustomCipherSuites() { System.setProperty(x509Util.getCipherSuitesProperty(), customCipherSuites[0] + "," + customCipherSuites[1]); http://git-wip-us.apache.org/repos/asf/zookeeper/blob/64104eae/zookeeper-server/src/test/java/org/apache/zookeeper/server/quorum/QuorumSSLTest.java ---------------------------------------------------------------------- diff --git a/zookeeper-server/src/test/java/org/apache/zookeeper/server/quorum/QuorumSSLTest.java b/zookeeper-server/src/test/java/org/apache/zookeeper/server/quorum/QuorumSSLTest.java index b088f47..67c15ad 100644 --- a/zookeeper-server/src/test/java/org/apache/zookeeper/server/quorum/QuorumSSLTest.java +++ b/zookeeper-server/src/test/java/org/apache/zookeeper/server/quorum/QuorumSSLTest.java @@ -80,7 +80,6 @@ import org.bouncycastle.util.io.pem.PemWriter; import org.junit.After; import org.junit.Assert; import org.junit.Before; -import org.junit.Ignore; import org.junit.Rule; import org.junit.Test; import org.junit.rules.Timeout; @@ -442,7 +441,6 @@ public class QuorumSSLTest extends QuorumPeerTestBase { Assert.assertFalse(ClientBase.waitForServerUp("127.0.0.1:" + clientPortQp3, CONNECTION_TIMEOUT)); } - @Ignore("portUnification is currently broken and disabled") @Test public void testRollingUpgrade() throws Exception { // Form a quorum without ssl 
http://git-wip-us.apache.org/repos/asf/zookeeper/blob/64104eae/zookeeper-server/src/test/java/org/apache/zookeeper/server/quorum/UnifiedServerSocketModeDetectionTest.java ---------------------------------------------------------------------- diff --git a/zookeeper-server/src/test/java/org/apache/zookeeper/server/quorum/UnifiedServerSocketModeDetectionTest.java b/zookeeper-server/src/test/java/org/apache/zookeeper/server/quorum/UnifiedServerSocketModeDetectionTest.java new file mode 100644 index 0000000..61862a4 --- /dev/null +++ b/zookeeper-server/src/test/java/org/apache/zookeeper/server/quorum/UnifiedServerSocketModeDetectionTest.java @@ -0,0 +1,405 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.zookeeper.server.quorum; + +import java.io.File; +import java.io.IOException; +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.net.ServerSocket; +import java.net.Socket; +import java.security.Security; +import java.util.ArrayList; +import java.util.Collection; +import java.util.concurrent.Callable; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; + +import org.apache.commons.io.FileUtils; +import org.apache.zookeeper.PortAssignment; +import org.apache.zookeeper.ZKTestCase; +import org.apache.zookeeper.common.ClientX509Util; +import org.apache.zookeeper.common.KeyStoreFileType; +import org.apache.zookeeper.common.X509KeyType; +import org.apache.zookeeper.common.X509TestContext; +import org.apache.zookeeper.common.X509Util; +import org.apache.zookeeper.test.ClientBase; +import org.bouncycastle.jce.provider.BouncyCastleProvider; +import org.junit.After; +import org.junit.AfterClass; +import org.junit.Assert; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * This test makes sure that certain operations on a UnifiedServerSocket do not + * trigger blocking mode detection. This is necessary to ensure that the + * Leader's accept() thread doesn't get blocked. 
+ */ +@RunWith(Parameterized.class) +public class UnifiedServerSocketModeDetectionTest extends ZKTestCase { + private static final Logger LOG = LoggerFactory.getLogger( + UnifiedServerSocketModeDetectionTest.class); + + @Parameterized.Parameters + public static Collection<Object[]> params() { + ArrayList<Object[]> result = new ArrayList<>(); + result.add(new Object[] { true }); + result.add(new Object[] { false }); + return result; + } + + private static File tempDir; + private static X509TestContext x509TestContext; + + private boolean useSecureClient; + private X509Util x509Util; + private UnifiedServerSocket listeningSocket; + private UnifiedServerSocket.UnifiedSocket serverSideSocket; + private Socket clientSocket; + private ExecutorService workerPool; + private int port; + private InetSocketAddress localServerAddress; + + @BeforeClass + public static void setUpClass() throws Exception { + Security.addProvider(new BouncyCastleProvider()); + tempDir = ClientBase.createEmptyTestDir(); + x509TestContext = X509TestContext.newBuilder() + .setTempDir(tempDir) + .setKeyStoreKeyType(X509KeyType.EC) + .setTrustStoreKeyType(X509KeyType.EC) + .build(); + } + + @AfterClass + public static void tearDownClass() { + try { + FileUtils.deleteDirectory(tempDir); + } catch (IOException ignored) { + // ignore + } + Security.removeProvider(BouncyCastleProvider.PROVIDER_NAME); + } + + private static void forceClose(Socket s) { + if (s == null || s.isClosed()) { + return; + } + try { + s.close(); + } catch (IOException ignored) { + // best-effort close during test cleanup; nothing useful to do on failure + } + } + + private static void forceClose(ServerSocket s) { + if (s == null || s.isClosed()) { + return; + } + try { + s.close(); + } catch (IOException ignored) { + // best-effort close during test cleanup; nothing useful to do on failure + } + } + + public UnifiedServerSocketModeDetectionTest(Boolean useSecureClient) { + this.useSecureClient = useSecureClient; + } + + @Before + public void setUp() throws Exception 
{ + x509Util = new ClientX509Util(); + x509TestContext.setSystemProperties(x509Util, KeyStoreFileType.JKS, KeyStoreFileType.JKS); + System.setProperty(x509Util.getSslHandshakeDetectionTimeoutMillisProperty(), "100"); + workerPool = Executors.newCachedThreadPool(); + port = PortAssignment.unique(); + localServerAddress = new InetSocketAddress(InetAddress.getLoopbackAddress(), port); + listeningSocket = new UnifiedServerSocket(x509Util, true); + listeningSocket.bind(localServerAddress); + Future<UnifiedServerSocket.UnifiedSocket> acceptFuture; + acceptFuture = workerPool.submit(new Callable<UnifiedServerSocket.UnifiedSocket>() { + @Override + public UnifiedServerSocket.UnifiedSocket call() throws Exception { + try { + return (UnifiedServerSocket.UnifiedSocket) listeningSocket.accept(); + } catch (IOException e) { + LOG.error("Error in accept(): ", e); + throw e; + } + } + }); + if (useSecureClient) { + clientSocket = x509Util.createSSLSocket(); + clientSocket.connect(localServerAddress); + } else { + clientSocket = new Socket(); + clientSocket.connect(localServerAddress); + clientSocket.getOutputStream().write(new byte[] { 1, 2, 3, 4, 5 }); + } + serverSideSocket = acceptFuture.get(); + } + + @After + public void tearDown() throws Exception { + x509TestContext.clearSystemProperties(x509Util); + System.clearProperty(x509Util.getSslHandshakeDetectionTimeoutMillisProperty()); + forceClose(listeningSocket); + forceClose(serverSideSocket); + forceClose(clientSocket); + workerPool.shutdown(); + workerPool.awaitTermination(1000, TimeUnit.MILLISECONDS); + } + + @Test + public void testGetInetAddress() { + serverSideSocket.getInetAddress(); + Assert.assertFalse(serverSideSocket.isModeKnown()); + } + + @Test + public void testGetLocalAddress() { + serverSideSocket.getLocalAddress(); + Assert.assertFalse(serverSideSocket.isModeKnown()); + } + + @Test + public void testGetPort() { + serverSideSocket.getPort(); + Assert.assertFalse(serverSideSocket.isModeKnown()); + } + + @Test + public void 
testGetLocalPort() { + serverSideSocket.getLocalPort(); + Assert.assertFalse(serverSideSocket.isModeKnown()); + } + + 
+ @Test + public void testGetRemoteSocketAddress() { + serverSideSocket.getRemoteSocketAddress(); + Assert.assertFalse(serverSideSocket.isModeKnown()); + } + + @Test + public void testGetLocalSocketAddress() { + serverSideSocket.getLocalSocketAddress(); + Assert.assertFalse(serverSideSocket.isModeKnown()); + } + + @Test + public void testGetInputStream() throws IOException { + serverSideSocket.getInputStream(); + Assert.assertFalse(serverSideSocket.isModeKnown()); + } + + @Test + public void testGetOutputStream() throws IOException { + serverSideSocket.getOutputStream(); + Assert.assertFalse(serverSideSocket.isModeKnown()); + } + + @Test + public void testGetTcpNoDelay() throws IOException { + serverSideSocket.getTcpNoDelay(); + Assert.assertFalse(serverSideSocket.isModeKnown()); + } + + @Test + public void testSetTcpNoDelay() throws IOException { + boolean tcpNoDelay = serverSideSocket.getTcpNoDelay(); + tcpNoDelay = !tcpNoDelay; + serverSideSocket.setTcpNoDelay(tcpNoDelay); + Assert.assertFalse(serverSideSocket.isModeKnown()); + Assert.assertEquals(tcpNoDelay, serverSideSocket.getTcpNoDelay()); + } + + @Test + public void testGetSoLinger() throws IOException { + serverSideSocket.getSoLinger(); + Assert.assertFalse(serverSideSocket.isModeKnown()); + } + + @Test + public void testSetSoLinger() throws IOException { + int soLinger = serverSideSocket.getSoLinger(); + if (soLinger == -1) { + // enable it if disabled + serverSideSocket.setSoLinger(true, 1); + Assert.assertFalse(serverSideSocket.isModeKnown()); + Assert.assertEquals(1, serverSideSocket.getSoLinger()); + } else { + // disable it if enabled + serverSideSocket.setSoLinger(false, -1); + Assert.assertFalse(serverSideSocket.isModeKnown()); + Assert.assertEquals(-1, serverSideSocket.getSoLinger()); + } + } + + @Test + public void testGetSoTimeout() throws IOException { + serverSideSocket.getSoTimeout(); + Assert.assertFalse(serverSideSocket.isModeKnown()); + } + + @Test + public void testSetSoTimeout() throws 
IOException { + int timeout = serverSideSocket.getSoTimeout(); + timeout = timeout + 10; + serverSideSocket.setSoTimeout(timeout); + Assert.assertFalse(serverSideSocket.isModeKnown()); + Assert.assertEquals(timeout, serverSideSocket.getSoTimeout()); + } + + @Test + public void testGetSendBufferSize() throws IOException { + serverSideSocket.getSendBufferSize(); + Assert.assertFalse(serverSideSocket.isModeKnown()); + } + + @Test + public void testSetSendBufferSize() throws IOException { + serverSideSocket.setSendBufferSize(serverSideSocket.getSendBufferSize() + 1024); + Assert.assertFalse(serverSideSocket.isModeKnown()); + // Note: the new buffer size is a hint and socket implementation + // is free to ignore it, so we don't verify that we get back the + // same value. + + } + + @Test + public void testGetReceiveBufferSize() throws IOException { + serverSideSocket.getReceiveBufferSize(); + Assert.assertFalse(serverSideSocket.isModeKnown()); + } + + @Test + public void testSetReceiveBufferSize() throws IOException { + serverSideSocket.setReceiveBufferSize(serverSideSocket.getReceiveBufferSize() + 1024); + Assert.assertFalse(serverSideSocket.isModeKnown()); + // Note: the new buffer size is a hint and socket implementation + // is free to ignore it, so we don't verify that we get back the + // same value. 
+ + } + + @Test + public void testGetKeepAlive() throws IOException { + serverSideSocket.getKeepAlive(); + Assert.assertFalse(serverSideSocket.isModeKnown()); + } + + @Test + public void testSetKeepAlive() throws IOException { + boolean keepAlive = serverSideSocket.getKeepAlive(); + keepAlive = !keepAlive; + serverSideSocket.setKeepAlive(keepAlive); + Assert.assertFalse(serverSideSocket.isModeKnown()); + Assert.assertEquals(keepAlive, serverSideSocket.getKeepAlive()); + } + + @Test + public void testGetTrafficClass() throws IOException { + serverSideSocket.getTrafficClass(); + Assert.assertFalse(serverSideSocket.isModeKnown()); + } + + @Test + public void testSetTrafficClass() throws IOException { + serverSideSocket.setTrafficClass(0x10); // IPTOS_LOWDELAY, a real TOS value (see Socket#setTrafficClass javadoc) + Assert.assertFalse(serverSideSocket.isModeKnown()); + // Note: according to the Socket javadocs, setTrafficClass() may be + // ignored by socket implementations, so we don't check that the value + // we set is returned. + } + + @Test + public void testGetReuseAddress() throws IOException { + serverSideSocket.getReuseAddress(); + Assert.assertFalse(serverSideSocket.isModeKnown()); + } + + @Test + public void testSetReuseAddress() throws IOException { + boolean reuseAddress = serverSideSocket.getReuseAddress(); + reuseAddress = !reuseAddress; + serverSideSocket.setReuseAddress(reuseAddress); + Assert.assertFalse(serverSideSocket.isModeKnown()); + Assert.assertEquals(reuseAddress, serverSideSocket.getReuseAddress()); + } + + @Test + public void testClose() throws IOException { + serverSideSocket.close(); + Assert.assertFalse(serverSideSocket.isModeKnown()); + } + + @Test + public void testShutdownInput() throws IOException { + serverSideSocket.shutdownInput(); + Assert.assertFalse(serverSideSocket.isModeKnown()); + } + + @Test + public void testShutdownOutput() throws IOException { + serverSideSocket.shutdownOutput(); + Assert.assertFalse(serverSideSocket.isModeKnown()); + } + 
+ @Test + public void testIsConnected() 
{ + serverSideSocket.isConnected(); + Assert.assertFalse(serverSideSocket.isModeKnown()); + } + + @Test + public void testIsBound() { + serverSideSocket.isBound(); + Assert.assertFalse(serverSideSocket.isModeKnown()); + } + + @Test + public void testIsClosed() { + serverSideSocket.isClosed(); + Assert.assertFalse(serverSideSocket.isModeKnown()); + } + + @Test + public void testIsInputShutdown() throws IOException { + serverSideSocket.isInputShutdown(); + Assert.assertFalse(serverSideSocket.isModeKnown()); + serverSideSocket.shutdownInput(); + Assert.assertTrue(serverSideSocket.isInputShutdown()); + } + + @Test + public void testIsOutputShutdown() throws IOException { + serverSideSocket.isOutputShutdown(); + Assert.assertFalse(serverSideSocket.isModeKnown()); + serverSideSocket.shutdownOutput(); + Assert.assertTrue(serverSideSocket.isOutputShutdown()); + } +}
