This is an automated email from the ASF dual-hosted git repository.

aleksey pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/cassandra.git

commit cd20f882820230980107ce5737c47a78a6fc4901
Merge: 0a9e129 aa317ae
Author: Aleksey Yeshchenko <[email protected]>
AuthorDate: Fri Jan 29 10:55:06 2021 +0000

    Merge branch 'cassandra-3.11' into trunk

 .../apache/cassandra/config/EncryptionOptions.java |  18 +++
 .../cassandra/net/InboundConnectionInitiator.java  |  18 +++
 .../cassandra/net/InboundMessageHandlers.java      |   8 ++
 .../test/InternodeEncryptionEnforcementTest.java   | 134 +++++++++++++++++++++
 4 files changed, 178 insertions(+)

diff --cc src/java/org/apache/cassandra/config/EncryptionOptions.java
index a534e36,1b1d8ce..93668d9
--- a/src/java/org/apache/cassandra/config/EncryptionOptions.java
+++ b/src/java/org/apache/cassandra/config/EncryptionOptions.java
@@@ -466,63 -54,13 +466,72 @@@ public class EncryptionOption
              all, none, dc, rack
          }
  
 -        public InternodeEncryption internode_encryption = 
InternodeEncryption.none;
 +        public final InternodeEncryption internode_encryption;
 +        public final boolean enable_legacy_ssl_storage_port;
  
 -        public boolean shouldEncrypt(InetAddress endpoint)
++        /** Creates options with {@code internode_encryption} = {@code none} and {@code enable_legacy_ssl_storage_port} = false. */
 +        public ServerEncryptionOptions()
          {
 -            IEndpointSnitch snitch = DatabaseDescriptor.getEndpointSnitch();
 -            InetAddress local = FBUtilities.getBroadcastAddress();
 +            this.internode_encryption = InternodeEncryption.none;
 +            this.enable_legacy_ssl_storage_port = false;
 +        }
 +
++        /**
++         * Creates options from explicit values. Every argument except {@code internode_encryption} and
++         * {@code enable_legacy_ssl_storage_port} is forwarded to the parent constructor; those two are stored here.
++         */
 +        public ServerEncryptionOptions(String keystore, String 
keystore_password, String truststore,
 +                                       String truststore_password, 
List<String> cipher_suites, String protocol,
 +                                       List<String> accepted_protocols, 
String algorithm, String store_type,
 +                                       boolean require_client_auth, boolean 
require_endpoint_verification,
 +                                       Boolean optional, InternodeEncryption 
internode_encryption,
 +                                       boolean enable_legacy_ssl_storage_port)
 +        {
++            // NOTE(review): the literal null below fills the parent parameter just before 'optional';
++            // presumably the parent's 'enabled' flag, left unset here - confirm against EncryptionOptions.
 +            super(keystore, keystore_password, truststore, 
truststore_password, cipher_suites, protocol,
 +                  accepted_protocols, algorithm, store_type, 
require_client_auth, require_endpoint_verification,
 +                  null, optional);
 +            this.internode_encryption = internode_encryption;
 +            this.enable_legacy_ssl_storage_port = 
enable_legacy_ssl_storage_port;
 +        }
 +
++        /** Copy constructor: delegates to {@code super(options)} and copies the two server-specific fields. */
 +        public ServerEncryptionOptions(ServerEncryptionOptions options)
 +        {
 +            super(options);
 +            this.internode_encryption = options.internode_encryption;
 +            this.enable_legacy_ssl_storage_port = 
options.enable_legacy_ssl_storage_port;
 +        }
 +
++        /**
++         * Delegates to {@code applyConfigInternal()}, which returns {@code this} with the narrower
++         * {@link ServerEncryptionOptions} type so the {@code with*} copy methods can return it directly.
++         */
 +        @Override
 +        public EncryptionOptions applyConfig()
 +        {
 +            return applyConfigInternal();
 +        }
  
 +        private ServerEncryptionOptions applyConfigInternal()
 +        {
++            // Runs the parent applyConfig(), then derives server-side state from internode_encryption:
++            // isEnabled is true for any value other than 'none', and isOptional is forced to true for 'rack'/'dc'.
++            // Warns about settings that are ignored or insecure. Returns this instance.
 +            super.applyConfig();
 +
 +            isEnabled = this.internode_encryption != InternodeEncryption.none;
 +
 +            if (this.enabled != null && this.enabled && !isEnabled)
 +            {
 +                logger.warn("Setting server_encryption_options.enabled has no 
effect, use internode_encryption");
 +            }
 +
++            if (require_client_auth && (internode_encryption == 
InternodeEncryption.rack || internode_encryption == InternodeEncryption.dc))
++            {
++                logger.warn("Setting require_client_auth is incompatible with 
'rack' and 'dc' internode_encryption values."
++                          + " It is possible for an internode connection to 
pretend to be in the same rack/dc by spoofing"
++                          + " its broadcast address in the handshake and 
bypass authentication. To ensure that mutual TLS"
++                          + " authentication is not bypassed, please set 
internode_encryption to 'all'. Continuing with"
++                          + " insecure configuration.");
++            }
++
 +            // regardless of the optional flag, if the internode encryption 
is set to rack or dc
 +            // it must be optional so that unencrypted connections within the 
rack or dc can be established.
 +            isOptional = super.isOptional || internode_encryption == 
InternodeEncryption.rack || internode_encryption == InternodeEncryption.dc;
 +
 +            return this;
 +        }
 +
 +        public boolean shouldEncrypt(InetAddressAndPort endpoint)
 +        {
 +            IEndpointSnitch snitch = DatabaseDescriptor.getEndpointSnitch();
              switch (internode_encryption)
              {
                  case none:
@@@ -543,112 -81,16 +552,121 @@@
              return true;
          }
  
 -        public void validate()
++        /**
++         * {@link #isOptional} will be set to {@code true} implicitly for {@code internode_encryption}
++         * values of "rack" and "dc". This method returns the explicit, raw value of {@link #optional}
++         * as set by the user: {@code true} only if it was explicitly set to {@code true}, {@code false}
++         * if it was set to {@code false} or never set at all.
++         */
++        public boolean isExplicitlyOptional()
+         {
 -            if (require_client_auth && (internode_encryption == 
InternodeEncryption.rack || internode_encryption == InternodeEncryption.dc))
 -            {
 -                logger.warn("Setting require_client_auth is incompatible with 
'rack' and 'dc' internode_encryption values."
 -                          + " It is possible for an internode connection to 
pretend to be in the same rack/dc by spoofing"
 -                          + " its broadcast address in the handshake and 
bypass authentication. To ensure that mutual TLS"
 -                          + " authentication is not bypassed, please set 
internode_encryption to 'all'. Continuing with"
 -                          + " insecure configuration.");
 -            }
++            return optional != null && optional;
++        }
 +
++        /** Returns a copy of these options with {@code keystore} replaced; {@code applyConfigInternal()} is run on the copy. */
 +        public ServerEncryptionOptions withKeyStore(String keystore)
 +        {
 +            return new ServerEncryptionOptions(keystore, keystore_password, 
truststore, truststore_password, cipher_suites,
 +                                               protocol, accepted_protocols, 
algorithm, store_type, require_client_auth, require_endpoint_verification,
 +                                               optional, 
internode_encryption, enable_legacy_ssl_storage_port).applyConfigInternal();
 +        }
 +
++        /** Returns a copy of these options with {@code keystore_password} replaced; {@code applyConfigInternal()} is run on the copy. */
 +        public ServerEncryptionOptions withKeyStorePassword(String 
keystore_password)
 +        {
 +            return new ServerEncryptionOptions(keystore, keystore_password, 
truststore, truststore_password, cipher_suites,
 +                                               protocol, accepted_protocols, 
algorithm, store_type, require_client_auth, require_endpoint_verification,
 +                                               optional, 
internode_encryption, enable_legacy_ssl_storage_port).applyConfigInternal();
 +        }
 +
++        /** Returns a copy of these options with {@code truststore} replaced; {@code applyConfigInternal()} is run on the copy. */
 +        public ServerEncryptionOptions withTrustStore(String truststore)
 +        {
 +            return new ServerEncryptionOptions(keystore, keystore_password, 
truststore, truststore_password, cipher_suites,
 +                                               protocol, accepted_protocols, 
algorithm, store_type, require_client_auth, require_endpoint_verification,
 +                                               optional, 
internode_encryption, enable_legacy_ssl_storage_port).applyConfigInternal();
 +        }
 +
++        /** Returns a copy of these options with {@code truststore_password} replaced; {@code applyConfigInternal()} is run on the copy. */
 +        public ServerEncryptionOptions withTrustStorePassword(String 
truststore_password)
 +        {
 +            return new ServerEncryptionOptions(keystore, keystore_password, 
truststore, truststore_password, cipher_suites,
 +                                               protocol, accepted_protocols, 
algorithm, store_type, require_client_auth, require_endpoint_verification,
 +                                               optional, 
internode_encryption, enable_legacy_ssl_storage_port).applyConfigInternal();
 +        }
 +
++        /**
++         * Returns a copy of these options with {@code cipher_suites} replaced; {@code applyConfigInternal()} is run
++         * on the copy. A non-null list is copied into an {@code ImmutableList} (as the varargs overload and
++         * {@code withAcceptedProtocols} already do), so later changes to the caller's list cannot leak into the
++         * returned options; {@code null} is passed through unchanged. {@code ImmutableList.copyOf} rejects
++         * null elements.
++         */
 +        public ServerEncryptionOptions withCipherSuites(List<String> 
cipher_suites)
 +        {
-             return new ServerEncryptionOptions(keystore, keystore_password, 
truststore, truststore_password, cipher_suites,
++            return new ServerEncryptionOptions(keystore, keystore_password, truststore, truststore_password,
++                                               cipher_suites == null ? null : ImmutableList.copyOf(cipher_suites),
 +                                               protocol, accepted_protocols, 
algorithm, store_type, require_client_auth, require_endpoint_verification,
 +                                               optional, 
internode_encryption, enable_legacy_ssl_storage_port).applyConfigInternal();
 +        }
 +
++        /** Returns a copy of these options with {@code cipher_suites} replaced by an immutable copy of the varargs; {@code applyConfigInternal()} is run on the copy. */
 +        public ServerEncryptionOptions withCipherSuites(String ... 
cipher_suites)
 +        {
 +            return new ServerEncryptionOptions(keystore, keystore_password, 
truststore, truststore_password, ImmutableList.copyOf(cipher_suites),
 +                                               protocol, accepted_protocols, 
algorithm, store_type, require_client_auth, require_endpoint_verification,
 +                                               optional, 
internode_encryption, enable_legacy_ssl_storage_port).applyConfigInternal();
 +        }
 +
++        /** Returns a copy of these options with {@code protocol} replaced; {@code applyConfigInternal()} is run on the copy. */
 +        public ServerEncryptionOptions withProtocol(String protocol)
 +        {
 +            return new ServerEncryptionOptions(keystore, keystore_password, 
truststore, truststore_password, cipher_suites,
 +                                               protocol, accepted_protocols, 
algorithm, store_type, require_client_auth, require_endpoint_verification,
 +                                               optional, 
internode_encryption, enable_legacy_ssl_storage_port).applyConfigInternal();
 +        }
 +
++        /**
++         * Returns a copy of these options with {@code accepted_protocols} replaced; a non-null list is copied into an
++         * {@code ImmutableList}, {@code null} is passed through. {@code applyConfigInternal()} is run on the copy.
++         */
 +        public ServerEncryptionOptions withAcceptedProtocols(List<String> 
accepted_protocols)
 +        {
 +            return new ServerEncryptionOptions(keystore, keystore_password, 
truststore, truststore_password, cipher_suites,
 +                                               protocol, accepted_protocols 
== null ? null : ImmutableList.copyOf(accepted_protocols),
 +                                               algorithm, store_type, 
require_client_auth, require_endpoint_verification,
 +                                               optional, 
internode_encryption, enable_legacy_ssl_storage_port).applyConfigInternal();
 +        }
 +
++        /** Returns a copy of these options with {@code algorithm} replaced; {@code applyConfigInternal()} is run on the copy. */
 +        public ServerEncryptionOptions withAlgorithm(String algorithm)
 +        {
 +            return new ServerEncryptionOptions(keystore, keystore_password, 
truststore, truststore_password, cipher_suites,
 +                                               protocol, accepted_protocols, 
algorithm, store_type, require_client_auth, require_endpoint_verification,
 +                                               optional, 
internode_encryption, enable_legacy_ssl_storage_port).applyConfigInternal();
 +        }
 +
++        /** Returns a copy of these options with {@code store_type} replaced; {@code applyConfigInternal()} is run on the copy. */
 +        public ServerEncryptionOptions withStoreType(String store_type)
 +        {
 +            return new ServerEncryptionOptions(keystore, keystore_password, 
truststore, truststore_password, cipher_suites,
 +                                               protocol, accepted_protocols, 
algorithm, store_type, require_client_auth, require_endpoint_verification,
 +                                               optional, 
internode_encryption, enable_legacy_ssl_storage_port).applyConfigInternal();
 +        }
 +
++        /**
++         * Returns a copy of these options with {@code require_client_auth} replaced; {@code applyConfigInternal()}
++         * is run on the copy, so combining {@code true} with 'rack'/'dc' internode encryption logs its warning.
++         */
 +        public ServerEncryptionOptions withRequireClientAuth(boolean 
require_client_auth)
 +        {
 +            return new ServerEncryptionOptions(keystore, keystore_password, 
truststore, truststore_password, cipher_suites,
 +                                               protocol, accepted_protocols, 
algorithm, store_type, require_client_auth, require_endpoint_verification,
 +                                               optional, 
internode_encryption, enable_legacy_ssl_storage_port).applyConfigInternal();
          }
 +
++        /** Returns a copy of these options with {@code require_endpoint_verification} replaced; {@code applyConfigInternal()} is run on the copy. */
 +        public ServerEncryptionOptions 
withRequireEndpointVerification(boolean require_endpoint_verification)
 +        {
 +            return new ServerEncryptionOptions(keystore, keystore_password, 
truststore, truststore_password, cipher_suites,
 +                                               protocol, accepted_protocols, 
algorithm, store_type, require_client_auth, require_endpoint_verification,
 +                                               optional, 
internode_encryption, enable_legacy_ssl_storage_port).applyConfigInternal();
 +        }
 +
++        /**
++         * Returns a copy of these options with the raw {@code optional} value replaced; {@code applyConfigInternal()}
++         * is run on the copy, so 'rack'/'dc' internode encryption still forces {@code isOptional} to true.
++         */
 +        public ServerEncryptionOptions withOptional(boolean optional)
 +        {
 +            return new ServerEncryptionOptions(keystore, keystore_password, 
truststore, truststore_password, cipher_suites,
 +                                               protocol, accepted_protocols, 
algorithm, store_type, require_client_auth, require_endpoint_verification,
 +                                               optional, 
internode_encryption, enable_legacy_ssl_storage_port).applyConfigInternal();
 +        }
 +
++        /**
++         * Returns a copy of these options with {@code internode_encryption} replaced; {@code applyConfigInternal()}
++         * is run on the copy, which re-derives {@code isEnabled} and {@code isOptional} from the new value.
++         */
 +        public ServerEncryptionOptions 
withInternodeEncryption(InternodeEncryption internode_encryption)
 +        {
 +            return new ServerEncryptionOptions(keystore, keystore_password, 
truststore, truststore_password, cipher_suites,
 +                                               protocol, accepted_protocols, 
algorithm, store_type, require_client_auth, require_endpoint_verification,
 +                                               optional, 
internode_encryption, enable_legacy_ssl_storage_port).applyConfigInternal();
 +        }
 +
++        /** Returns a copy of these options with {@code enable_legacy_ssl_storage_port} replaced; {@code applyConfigInternal()} is run on the copy. */
 +        public ServerEncryptionOptions withLegacySslStoragePort(boolean 
enable_legacy_ssl_storage_port)
 +        {
 +            return new ServerEncryptionOptions(keystore, keystore_password, 
truststore, truststore_password, cipher_suites,
 +                                               protocol, accepted_protocols, 
algorithm, store_type, require_client_auth, require_endpoint_verification,
 +                                               optional, 
internode_encryption, enable_legacy_ssl_storage_port).applyConfigInternal();
 +        }
 +
      }
  }
diff --cc src/java/org/apache/cassandra/net/InboundConnectionInitiator.java
index d21358a,0000000..4c31adf
mode 100644,000000..100644
--- a/src/java/org/apache/cassandra/net/InboundConnectionInitiator.java
+++ b/src/java/org/apache/cassandra/net/InboundConnectionInitiator.java
@@@ -1,527 -1,0 +1,545 @@@
 +/*
 + * Licensed to the Apache Software Foundation (ASF) under one
 + * or more contributor license agreements.  See the NOTICE file
 + * distributed with this work for additional information
 + * regarding copyright ownership.  The ASF licenses this file
 + * to you under the Apache License, Version 2.0 (the
 + * "License"); you may not use this file except in compliance
 + * with the License.  You may obtain a copy of the License at
 + *
 + *     http://www.apache.org/licenses/LICENSE-2.0
 + *
 + * Unless required by applicable law or agreed to in writing, software
 + * distributed under the License is distributed on an "AS IS" BASIS,
 + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 + * See the License for the specific language governing permissions and
 + * limitations under the License.
 + */
 +package org.apache.cassandra.net;
 +
 +import java.io.IOException;
 +import java.net.InetSocketAddress;
 +import java.net.SocketAddress;
 +import java.util.List;
 +import java.util.concurrent.Future;
 +import java.util.function.Consumer;
 +
 +import com.google.common.annotations.VisibleForTesting;
 +import org.slf4j.Logger;
 +import org.slf4j.LoggerFactory;
 +
 +import io.netty.bootstrap.ServerBootstrap;
 +import io.netty.buffer.ByteBuf;
 +import io.netty.channel.Channel;
 +import io.netty.channel.ChannelFuture;
 +import io.netty.channel.ChannelFutureListener;
 +import io.netty.channel.ChannelHandlerContext;
 +import io.netty.channel.ChannelInitializer;
 +import io.netty.channel.ChannelOption;
 +import io.netty.channel.ChannelPipeline;
 +import io.netty.channel.group.ChannelGroup;
 +import io.netty.channel.socket.SocketChannel;
 +import io.netty.handler.codec.ByteToMessageDecoder;
 +import io.netty.handler.logging.LogLevel;
 +import io.netty.handler.logging.LoggingHandler;
 +import io.netty.handler.ssl.SslContext;
 +import io.netty.handler.ssl.SslHandler;
 +import org.apache.cassandra.config.EncryptionOptions;
 +import org.apache.cassandra.exceptions.ConfigurationException;
 +import org.apache.cassandra.locator.InetAddressAndPort;
 +import org.apache.cassandra.net.OutboundConnectionSettings.Framing;
 +import org.apache.cassandra.security.SSLFactory;
 +import org.apache.cassandra.streaming.async.StreamingInboundHandler;
 +import org.apache.cassandra.utils.memory.BufferPools;
 +
 +import static java.lang.Math.*;
 +import static java.util.concurrent.TimeUnit.MILLISECONDS;
 +import static org.apache.cassandra.net.MessagingService.*;
 +import static org.apache.cassandra.net.MessagingService.VERSION_40;
 +import static org.apache.cassandra.net.MessagingService.current_version;
 +import static org.apache.cassandra.net.MessagingService.minimum_version;
 +import static org.apache.cassandra.net.SocketFactory.WIRETRACE;
 +import static org.apache.cassandra.net.SocketFactory.newSslHandler;
 +
 +public class InboundConnectionInitiator
 +{
 +    private static final Logger logger = 
LoggerFactory.getLogger(InboundConnectionInitiator.class);
 +
++    /**
++     * Sets up each accepted connection: adds it to {@code channelGroup}, sets socket options, lets
++     * {@code pipelineInjector} customise the pipeline, installs the TLS-related handler chosen by
++     * {@code settings.encryption.tlsEncryptionPolicy()}, and finally appends the handshake {@link Handler}.
++     */
 +    private static class Initializer extends ChannelInitializer<SocketChannel>
 +    {
 +        private final InboundConnectionSettings settings;
 +        private final ChannelGroup channelGroup;
 +        private final Consumer<ChannelPipeline> pipelineInjector;
 +
 +        Initializer(InboundConnectionSettings settings, ChannelGroup 
channelGroup,
 +                    Consumer<ChannelPipeline> pipelineInjector)
 +        {
 +            this.settings = settings;
 +            this.channelGroup = channelGroup;
 +            this.pipelineInjector = pipelineInjector;
 +        }
 +
 +        @Override
 +        public void initChannel(SocketChannel channel) throws Exception
 +        {
 +            channelGroup.add(channel);
 +
 +            channel.config().setOption(ChannelOption.ALLOCATOR, 
GlobalBufferPoolAllocator.instance);
 +            channel.config().setOption(ChannelOption.SO_KEEPALIVE, true);
 +            channel.config().setOption(ChannelOption.SO_REUSEADDR, true);
 +            channel.config().setOption(ChannelOption.TCP_NODELAY, true); // 
we only send handshake messages; no point ever delaying
 +
 +            ChannelPipeline pipeline = channel.pipeline();
 +
++            // caller-supplied customisation runs before the TLS and handshake handlers are installed
 +            pipelineInjector.accept(pipeline);
 +
 +            // order of handlers: ssl -> logger -> handshakeHandler
 +            // For either unencrypted or transitional modes, allow Ssl 
optionally.
++            // Per policy: UNENCRYPTED rejects TLS attempts, OPTIONAL presumably accepts TLS or plaintext, ENCRYPTED
++            // always installs an SslHandler. Each is added with addFirst, i.e. ahead of anything pipelineInjector added.
 +            switch(settings.encryption.tlsEncryptionPolicy())
 +            {
 +                case UNENCRYPTED:
 +                    // Handler checks for SSL connection attempts and cleanly 
rejects them if encryption is disabled
 +                    pipeline.addFirst("rejectssl", new RejectSslHandler());
 +                    break;
 +                case OPTIONAL:
 +                    pipeline.addFirst("ssl", new 
OptionalSslHandler(settings.encryption));
 +                    break;
 +                case ENCRYPTED:
 +                    SslHandler sslHandler = getSslHandler("creating", 
channel, settings.encryption);
 +                    pipeline.addFirst("ssl", sslHandler);
 +                    break;
 +            }
 +
 +            if (WIRETRACE)
 +                pipeline.addLast("logger", new LoggingHandler(LogLevel.INFO));
 +
 +            channel.pipeline().addLast("handshake", new Handler(settings));
 +
 +        }
 +    }
 +
 +    /**
-      * Create a {@link Channel} that listens on the {@code localAddr}. This 
method will block while trying to bind to the address,
-      * but it does not make a remote call.
++     * Create a {@link Channel} that listens on {@code initializer.settings.bindAddress}. This method blocks
++     * (uninterruptibly) while trying to bind to the address, but it does not make a remote call.
++     *
++     * @param initializer supplies the settings (bind address, socket factory, receive buffer size) and becomes
++     *                    the child handler for every accepted connection
++     * @return the future of the successful bind
++     * @throws ConfigurationException if the bind fails; "in use" and "cannot assign requested address" failures
++     *                                get specific messages, any other failure carries the original cause
 +     */
 +    private static ChannelFuture bind(Initializer initializer) throws 
ConfigurationException
 +    {
 +        logger.info("Listening on {}", initializer.settings);
 +
 +        ServerBootstrap bootstrap = initializer.settings.socketFactory
 +                                    .newServerBootstrap()
 +                                    .option(ChannelOption.SO_BACKLOG, 1 << 9)
 +                                    .option(ChannelOption.ALLOCATOR, 
GlobalBufferPoolAllocator.instance)
 +                                    .option(ChannelOption.SO_REUSEADDR, true)
 +                                    .childHandler(initializer);
 +
 +        int socketReceiveBufferSizeInBytes = 
initializer.settings.socketReceiveBufferSizeInBytes;
 +        if (socketReceiveBufferSizeInBytes > 0)
 +            bootstrap.childOption(ChannelOption.SO_RCVBUF, 
socketReceiveBufferSizeInBytes);
 +
 +        InetAddressAndPort bind = initializer.settings.bindAddress;
 +        ChannelFuture channelFuture = bootstrap.bind(new 
InetSocketAddress(bind.address, bind.port));
 +
 +        if (!channelFuture.awaitUninterruptibly().isSuccess())
 +        {
 +            if (channelFuture.channel().isOpen())
 +                channelFuture.channel().close();
 +
 +            Throwable failedChannelCause = channelFuture.cause();
 +
 +            String causeString = "";
 +            if (failedChannelCause != null && failedChannelCause.getMessage() 
!= null)
 +                causeString = failedChannelCause.getMessage();
 +
 +            if (causeString.contains("in use"))
 +            {
 +                throw new ConfigurationException(bind + " is in use by 
another process.  Change listen_address:storage_port " +
 +                                                 "in cassandra.yaml to values 
that do not conflict with other services");
 +            }
-             // looking at the jdk source, solaris/windows bind failue 
messages both use the phrase "cannot assign requested address".
-             // windows message uses "Cannot" (with a capital 'C'), and 
solaris (a/k/a *nux) doe not. hence we search for "annot" <sigh>
++            // looking at the jdk source, solaris/windows bind failure messages both use the phrase "cannot assign requested address".
++            // windows message uses "Cannot" (with a capital 'C'), and solaris (a/k/a *nux) does not. hence we search for "annot" <sigh>
 +            else if (causeString.contains("annot assign requested address"))
 +            {
 +                throw new ConfigurationException("Unable to bind to address " 
+ bind
 +                                                 + ". Set listen_address in 
cassandra.yaml to an interface you can bind to, e.g., your private IP address 
on EC2");
 +            }
 +            else
 +            {
 +                throw new ConfigurationException("failed to bind to: " + 
bind, failedChannelCause);
 +            }
 +        }
 +
 +        return channelFuture;
 +    }
 +
++    /**
++     * Binds a server socket to {@code settings.bindAddress}, blocking until the bind completes. Each accepted
++     * connection is set up by an {@link Initializer} built from these arguments.
++     *
++     * @throws ConfigurationException if the address cannot be bound
++     */
 +    public static ChannelFuture bind(InboundConnectionSettings settings, 
ChannelGroup channelGroup,
 +                                     Consumer<ChannelPipeline> 
pipelineInjector)
 +    {
 +        return bind(new Initializer(settings, channelGroup, 
pipelineInjector));
 +    }
 +
 +    /**
 +     * 'Server-side' component that negotiates the internode handshake when 
establishing a new connection.
 +     * This handler will be the first in the netty channel for each incoming 
connection (secure socket (TLS) notwithstanding),
 +     * and once the handshake is successful, it will configure the proper 
handlers ({@link InboundMessageHandler}
 +     * or {@link StreamingInboundHandler}) and remove itself from the working 
pipeline.
 +     */
 +    static class Handler extends ByteToMessageDecoder
 +    {
 +        private final InboundConnectionSettings settings;
 +
 +        private HandshakeProtocol.Initiate initiate;
 +        private HandshakeProtocol.ConfirmOutboundPre40 confirmOutboundPre40;
 +
 +        /**
 +         * A future that essentially places a timeout on how long we'll wait 
for the peer
 +         * to complete the next step of the handshake.
 +         */
 +        private Future<?> handshakeTimeout;
 +
++        /** @param settings inbound settings (encryption, accepted versions, bind address) consulted throughout the handshake */
 +        Handler(InboundConnectionSettings settings)
 +        {
 +            this.settings = settings;
 +        }
 +
 +        /**
 +         * On registration, immediately schedule a timeout to kill this 
connection if it does not handshake promptly,
 +         * and authenticate the remote address.
 +         */
 +        public void handlerAdded(ChannelHandlerContext ctx) throws Exception
 +        {
 +            handshakeTimeout = ctx.executor().schedule(() -> {
 +                logger.error("Timeout handshaking with {} (on {})", 
SocketFactory.addressId(initiate.from, (InetSocketAddress) 
ctx.channel().remoteAddress()), settings.bindAddress);
 +                failHandshake(ctx);
 +            }, HandshakeProtocol.TIMEOUT_MILLIS, MILLISECONDS);
 +
 +            authenticate(ctx.channel().remoteAddress());
 +        }
 +
++        /**
++         * Rejects the connection unless {@code settings.authenticate} accepts the peer's IP address and port.
++         *
++         * @throws IOException if the address is not an {@link InetSocketAddress} or authentication fails
++         */
 +        private void authenticate(SocketAddress socketAddress) throws 
IOException
 +        {
++            // Netty's EmbeddedSocketAddress (presumably from EmbeddedChannel-based tests) skips authentication;
++            // it is matched by simple name, presumably because the class is not publicly accessible - confirm.
 +            if 
(socketAddress.getClass().getSimpleName().equals("EmbeddedSocketAddress"))
 +                return;
 +
 +            if (!(socketAddress instanceof InetSocketAddress))
 +                throw new IOException(String.format("Unexpected SocketAddress 
type: %s, %s", socketAddress.getClass(), socketAddress));
 +
 +            InetSocketAddress addr = (InetSocketAddress)socketAddress;
 +            if (!settings.authenticate(addr.getAddress(), addr.getPort()))
 +                throw new IOException("Authentication failure for inbound 
connection from peer " + addr);
 +        }
 +
++        /**
++         * Dispatches to the current handshake step: {@link #initiate} until the first message has been decoded,
++         * then {@link #confirmPre40} for peers whose first message carried no {@code acceptVersions} (pre-4.0).
++         * Once a step has installed the final pipeline this handler is expected to be gone, hence the exception.
++         */
 +        @Override
 +        protected void decode(ChannelHandlerContext ctx, ByteBuf in, 
List<Object> out) throws Exception
 +        {
 +            if (initiate == null) initiate(ctx, in);
 +            else if (initiate.acceptVersions == null && confirmOutboundPre40 
== null) confirmPre40(ctx, in);
 +            else throw new IllegalStateException("Should no longer be on 
pipeline");
 +        }
 +
++        /**
++         * Handles the first handshake message. Returns immediately while {@code maybeDecode} yields {@code null}.
++         * The handshake is failed, and nothing further is done, if encryption is required for the peer but the
++         * channel is plaintext. For 4.0+ peers an Accept carrying the negotiated version is always written; if the
++         * version ranges overlap the streaming or messaging pipeline is installed, otherwise the handshake fails.
++         * For pre-4.0 peers a streaming connection is installed only on an exact version match (otherwise the
++         * handshake fails), while a messaging connection gets a pre-4.0 reply and waits for the third message
++         * (see {@link #confirmPre40}).
++         *
++         * @throws IOException if a pre-4.0 messaging peer requests a version below {@code VERSION_30}
++         */
 +        void initiate(ChannelHandlerContext ctx, ByteBuf in) throws 
IOException
 +        {
 +            initiate = HandshakeProtocol.Initiate.maybeDecode(in);
 +            if (initiate == null)
 +                return;
 +
 +            logger.trace("Received handshake initiation message from peer {}, 
message = {}", ctx.channel().remoteAddress(), initiate);
++
++            // NOTE(review): initiate.from may be null here (setupStreamingPipeline handles a null 'from');
++            // confirm shouldEncrypt tolerates a null peer for the rack/dc modes.
++            if (isEncryptionRequired(initiate.from) && 
!isChannelEncrypted(ctx))
++            {
++                logger.warn("peer {} attempted to establish an unencrypted 
connection (broadcast address {})",
++                            ctx.channel().remoteAddress(), initiate.from);
++                failHandshake(ctx);
++                // the channel is now closed: do not reply or install a messaging/streaming pipeline for this peer
++                return;
++            }
++
 +            if (initiate.acceptVersions != null)
 +            {
 +                logger.trace("Connection version {} (min {}) from {}", 
initiate.acceptVersions.max, initiate.acceptVersions.min, initiate.from);
 +
 +                final AcceptVersions accept;
 +
 +                if (initiate.type.isStreaming())
 +                    accept = settings.acceptStreaming;
 +                else
 +                    accept = settings.acceptMessaging;
 +
 +                int useMessagingVersion = max(accept.min, min(accept.max, 
initiate.acceptVersions.max));
 +                ByteBuf flush = new 
HandshakeProtocol.Accept(useMessagingVersion, accept.max).encode(ctx.alloc());
 +
 +                AsyncChannelPromise.writeAndFlush(ctx, flush, 
(ChannelFutureListener) future -> {
 +                    if (!future.isSuccess())
 +                        exceptionCaught(future.channel(), future.cause());
 +                });
 +
 +                if (initiate.acceptVersions.min > accept.max)
 +                {
 +                    logger.info("peer {} only supports messaging versions 
higher ({}) than this node supports ({})", ctx.channel().remoteAddress(), 
initiate.acceptVersions.min, current_version);
 +                    failHandshake(ctx);
 +                }
 +                else if (initiate.acceptVersions.max < accept.min)
 +                {
 +                    logger.info("peer {} only supports messaging versions 
lower ({}) than this node supports ({})", ctx.channel().remoteAddress(), 
initiate.acceptVersions.max, minimum_version);
 +                    failHandshake(ctx);
 +                }
 +                else
 +                {
 +                    if (initiate.type.isStreaming())
 +                        setupStreamingPipeline(initiate.from, ctx);
 +                    else
 +                        setupMessagingPipeline(initiate.from, 
useMessagingVersion, initiate.acceptVersions.max, ctx.pipeline());
 +                }
 +            }
 +            else
 +            {
 +                int version = initiate.requestMessagingVersion;
 +                assert version < VERSION_40 && version >= 
settings.acceptMessaging.min;
 +                logger.trace("Connection version {} from {}", version, 
ctx.channel().remoteAddress());
 +
 +                if (initiate.type.isStreaming())
 +                {
 +                    // streaming connections are per-session and have a fixed 
version.  we can't do anything with a wrong-version stream connection, so drop 
it.
 +                    if (version != settings.acceptStreaming.max)
 +                    {
 +                        logger.warn("Received stream using protocol version 
{} (my version {}). Terminating connection", version, 
settings.acceptStreaming.max);
 +                        failHandshake(ctx);
++                        // dropped: do not go on to install a streaming pipeline on the closed channel
++                        return;
 +                    }
 +                    setupStreamingPipeline(initiate.from, ctx);
 +                }
 +                else
 +                {
 +                    // if this version is < the MS version the other node is 
trying
 +                    // to connect with, the other node will disconnect
 +                    ByteBuf response = 
HandshakeProtocol.Accept.respondPre40(settings.acceptMessaging.max, 
ctx.alloc());
 +                    AsyncChannelPromise.writeAndFlush(ctx, response,
 +                          (ChannelFutureListener) future -> {
 +                               if (!future.isSuccess())
 +                                   exceptionCaught(future.channel(), 
future.cause());
 +                    });
 +
 +                    if (version < VERSION_30)
 +                        throw new IOException(String.format("Unable to read 
obsolete message version %s from %s; The earliest version supported is 3.0.0", 
version, ctx.channel().remoteAddress()));
 +
 +                    // we don't setup the messaging pipeline here, as the 
legacy messaging handshake requires one more message to finish
 +                }
 +            }
 +        }
 +
++        /**
++         * Encryption is enforced for {@code peer} unless the operator explicitly set {@code optional} to true;
++         * otherwise {@code shouldEncrypt} decides. The caller in {@code initiate()} passes the address the peer
++         * reported in its handshake ({@code initiate.from}), not the socket's remote address.
++         */
++        private boolean isEncryptionRequired(InetAddressAndPort peer)
++        {
++            return !settings.encryption.isExplicitlyOptional() && 
settings.encryption.shouldEncrypt(peer);
++        }
++
++        /**
++         * Returns whether an {@link SslHandler} is currently in this channel's pipeline.
++         * NOTE(review): under the OPTIONAL policy the pipeline starts with an OptionalSslHandler; this check
++         * presumably relies on it being replaced by an SslHandler once TLS is detected - confirm.
++         */
++        private boolean isChannelEncrypted(ChannelHandlerContext ctx)
++        {
++            return ctx.pipeline().get(SslHandler.class) != null;
++        }
++
 +        /**
 +         * Handles the third (and last) message in the internode messaging 
handshake protocol for pre40 nodes.
 +         * Grabs the protocol version and IP addr the peer wants to use.
++         * <p>
++         * Returns without doing anything while {@code maybeDecode} yields {@code null}; the messaging pipeline is
++         * then set up with the version requested in the first message and the peer's maximum version from this one.
++         * NOTE(review): {@code confirmOutboundPre40.from} is not re-checked against the encryption policy here;
++         * {@code initiate()} only checked {@code initiate.from}.
 +         */
 +        @VisibleForTesting
 +        void confirmPre40(ChannelHandlerContext ctx, ByteBuf in)
 +        {
 +            confirmOutboundPre40 = 
HandshakeProtocol.ConfirmOutboundPre40.maybeDecode(in);
 +            if (confirmOutboundPre40 == null)
 +                return;
 +
 +            logger.trace("Received third handshake message from peer {}, 
message = {}", ctx.channel().remoteAddress(), confirmOutboundPre40);
 +            setupMessagingPipeline(confirmOutboundPre40.from, 
initiate.requestMessagingVersion, confirmOutboundPre40.maxMessagingVersion, 
ctx.pipeline());
 +        }
 +
 +        @Override
 +        public void exceptionCaught(ChannelHandlerContext ctx, Throwable 
cause)
 +        {
 +            exceptionCaught(ctx.channel(), cause);
 +        }
 +
 +        private void exceptionCaught(Channel channel, Throwable cause)
 +        {
 +            logger.error("Failed to properly handshake with peer {}. Closing 
the channel.", channel.remoteAddress(), cause);
 +            try
 +            {
 +                failHandshake(channel);
 +            }
 +            catch (Throwable t)
 +            {
 +                logger.error("Unexpected exception in {}.exceptionCaught", 
this.getClass().getSimpleName(), t);
 +            }
 +        }
 +
 +        private void failHandshake(ChannelHandlerContext ctx)
 +        {
 +            failHandshake(ctx.channel());
 +        }
 +
 +        private void failHandshake(Channel channel)
 +        {
 +            channel.close();
 +            if (handshakeTimeout != null)
 +                handshakeTimeout.cancel(true);
 +        }
 +
 +        private void setupStreamingPipeline(InetAddressAndPort from, 
ChannelHandlerContext ctx)
 +        {
 +            handshakeTimeout.cancel(true);
 +            assert initiate.framing == Framing.UNPROTECTED;
 +
 +            ChannelPipeline pipeline = ctx.pipeline();
 +            Channel channel = ctx.channel();
 +
 +            if (from == null)
 +            {
 +                InetSocketAddress address = (InetSocketAddress) 
channel.remoteAddress();
 +                from = 
InetAddressAndPort.getByAddressOverrideDefaults(address.getAddress(), 
address.getPort());
 +            }
 +
 +            
BufferPools.forNetworking().setRecycleWhenFreeForCurrentThread(false);
 +            pipeline.replace(this, "streamInbound", new 
StreamingInboundHandler(from, current_version, null));
 +
 +            logger.info("{} streaming connection established, version = {}, 
framing = {}, encryption = {}",
 +                        SocketFactory.channelId(from,
 +                                                (InetSocketAddress) 
channel.remoteAddress(),
 +                                                settings.bindAddress,
 +                                                (InetSocketAddress) 
channel.localAddress(),
 +                                                ConnectionType.STREAMING,
 +                                                channel.id().asShortText()),
 +                        current_version,
 +                        initiate.framing,
 +                        
SocketFactory.encryptionConnectionSummary(pipeline.channel()));
 +        }
 +
 +        @VisibleForTesting
 +        void setupMessagingPipeline(InetAddressAndPort from, int 
useMessagingVersion, int maxMessagingVersion, ChannelPipeline pipeline)
 +        {
 +            handshakeTimeout.cancel(true);
 +            // record the "true" endpoint, i.e. the one the peer is 
identified with, as opposed to the socket it connected over
 +            instance().versions.set(from, maxMessagingVersion);
 +
 +            
BufferPools.forNetworking().setRecycleWhenFreeForCurrentThread(false);
 +            BufferPoolAllocator allocator = 
GlobalBufferPoolAllocator.instance;
 +            if (initiate.type == ConnectionType.LARGE_MESSAGES)
 +            {
 +                // for large messages, swap the global pool allocator for a 
local one, to optimise utilisation of chunks
 +                allocator = new 
LocalBufferPoolAllocator(pipeline.channel().eventLoop());
 +                pipeline.channel().config().setAllocator(allocator);
 +            }
 +
 +            FrameDecoder frameDecoder;
 +            switch (initiate.framing)
 +            {
 +                case LZ4:
 +                {
 +                    if (useMessagingVersion >= VERSION_40)
 +                        frameDecoder = FrameDecoderLZ4.fast(allocator);
 +                    else
 +                        frameDecoder = new FrameDecoderLegacyLZ4(allocator, 
useMessagingVersion);
 +                    break;
 +                }
 +                case CRC:
 +                {
 +                    if (useMessagingVersion >= VERSION_40)
 +                    {
 +                        frameDecoder = FrameDecoderCrc.create(allocator);
 +                        break;
 +                    }
 +                }
 +                case UNPROTECTED:
 +                {
 +                    if (useMessagingVersion >= VERSION_40)
 +                        frameDecoder = new FrameDecoderUnprotected(allocator);
 +                    else
 +                        frameDecoder = new FrameDecoderLegacy(allocator, 
useMessagingVersion);
 +                    break;
 +                }
 +                default:
 +                    throw new AssertionError();
 +            }
 +
 +            frameDecoder.addLastTo(pipeline);
 +
 +            InboundMessageHandler handler =
 +                settings.handlers.apply(from).createHandler(frameDecoder, 
initiate.type, pipeline.channel(), useMessagingVersion);
 +
 +            logger.info("{} messaging connection established, version = {}, 
framing = {}, encryption = {}",
 +                        handler.id(true),
 +                        useMessagingVersion,
 +                        initiate.framing,
 +                        
SocketFactory.encryptionConnectionSummary(pipeline.channel()));
 +
 +            pipeline.addLast("deserialize", handler);
 +
 +            pipeline.remove(this);
 +        }
 +    }
 +
 +    private static SslHandler getSslHandler(String description, Channel 
channel, EncryptionOptions.ServerEncryptionOptions encryptionOptions) throws 
IOException
 +    {
 +        final boolean buildTrustStore = true;
 +        SslContext sslContext = 
SSLFactory.getOrCreateSslContext(encryptionOptions, buildTrustStore, 
SSLFactory.SocketType.SERVER);
 +        InetSocketAddress peer = 
encryptionOptions.require_endpoint_verification ? (InetSocketAddress) 
channel.remoteAddress() : null;
 +        SslHandler sslHandler = newSslHandler(channel, sslContext, peer);
 +        logger.trace("{} inbound netty SslContext: context={}, engine={}", 
description, sslContext.getClass().getName(), 
sslHandler.engine().getClass().getName());
 +        return sslHandler;
 +    }
 +
 +    private static class OptionalSslHandler extends ByteToMessageDecoder
 +    {
 +        private final EncryptionOptions.ServerEncryptionOptions 
encryptionOptions;
 +
 +        OptionalSslHandler(EncryptionOptions.ServerEncryptionOptions 
encryptionOptions)
 +        {
 +            this.encryptionOptions = encryptionOptions;
 +        }
 +
 +        protected void decode(ChannelHandlerContext ctx, ByteBuf in, 
List<Object> out) throws Exception
 +        {
 +            if (in.readableBytes() < 5)
 +            {
 +                // To detect if SSL must be used we need to have at least 5 
bytes, so return here and try again
 +                // once more bytes a ready.
 +                return;
 +            }
 +
 +            if (SslHandler.isEncrypted(in))
 +            {
 +                // Connection uses SSL/TLS, replace the detection handler 
with a SslHandler and so use encryption.
 +                SslHandler sslHandler = getSslHandler("replacing optional", 
ctx.channel(), encryptionOptions);
 +                ctx.pipeline().replace(this, "ssl", sslHandler);
 +            }
 +            else
 +            {
 +                // Connection use no TLS/SSL encryption, just remove the 
detection handler and continue without
 +                // SslHandler in the pipeline.
 +                ctx.pipeline().remove(this);
 +            }
 +        }
 +    }
 +
 +    private static class RejectSslHandler extends ByteToMessageDecoder
 +    {
 +        protected void decode(ChannelHandlerContext ctx, ByteBuf in, 
List<Object> out)
 +        {
 +            if (in.readableBytes() < 5)
 +            {
 +                // To detect if SSL must be used we need to have at least 5 
bytes, so return here and try again
 +                // once more bytes a ready.
 +                return;
 +            }
 +
 +            if (SslHandler.isEncrypted(in))
 +            {
 +                logger.info("Rejected incoming TLS connection before 
negotiating from {} to {}. TLS is explicitly disabled by configuration.",
 +                            ctx.channel().remoteAddress(), 
ctx.channel().localAddress());
 +                in.readBytes(in.readableBytes()); // discard the readable 
bytes so not called again
 +                ctx.close();
 +            }
 +            else
 +            {
 +                // Incoming connection did not attempt TLS/SSL encryption, 
just remove the detection handler and continue without
 +                // SslHandler in the pipeline.
 +                ctx.pipeline().remove(this);
 +            }
 +        }
 +    }
 +}
diff --cc src/java/org/apache/cassandra/net/InboundMessageHandlers.java
index 4466466,0000000..a706557
mode 100644,000000..100644
--- a/src/java/org/apache/cassandra/net/InboundMessageHandlers.java
+++ b/src/java/org/apache/cassandra/net/InboundMessageHandlers.java
@@@ -1,447 -1,0 +1,455 @@@
 +/*
 + * Licensed to the Apache Software Foundation (ASF) under one
 + * or more contributor license agreements.  See the NOTICE file
 + * distributed with this work for additional information
 + * regarding copyright ownership.  The ASF licenses this file
 + * to you under the Apache License, Version 2.0 (the
 + * "License"); you may not use this file except in compliance
 + * with the License.  You may obtain a copy of the License at
 + *
 + *     http://www.apache.org/licenses/LICENSE-2.0
 + *
 + * Unless required by applicable law or agreed to in writing, software
 + * distributed under the License is distributed on an "AS IS" BASIS,
 + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 + * See the License for the specific language governing permissions and
 + * limitations under the License.
 + */
 +package org.apache.cassandra.net;
 +
 +import java.util.Collection;
 +import java.util.concurrent.CopyOnWriteArrayList;
 +import java.util.concurrent.TimeUnit;
 +import java.util.concurrent.atomic.AtomicLongFieldUpdater;
 +import java.util.function.Consumer;
 +import java.util.function.ToLongFunction;
 +
++import com.google.common.annotations.VisibleForTesting;
++
 +import io.netty.channel.Channel;
 +import org.apache.cassandra.locator.InetAddressAndPort;
 +import org.apache.cassandra.metrics.InternodeInboundMetrics;
 +import org.apache.cassandra.net.Message.Header;
 +
 +import static java.util.concurrent.TimeUnit.NANOSECONDS;
 +import static org.apache.cassandra.utils.MonotonicClock.approxTime;
 +
 +/**
 + * An aggregation of {@link InboundMessageHandler}s for all connections from 
a peer.
 + *
 + * Manages metrics and shared resource limits. Can have multiple connections 
of a single
 + * type open simultaneousely (legacy in particular).
 + */
 +public final class InboundMessageHandlers
 +{
 +    private final InetAddressAndPort self;
 +    private final InetAddressAndPort peer;
 +
 +    private final int queueCapacity;
 +    private final ResourceLimits.Limit endpointReserveCapacity;
 +    private final ResourceLimits.Limit globalReserveCapacity;
 +
 +    private final InboundMessageHandler.WaitQueue endpointWaitQueue;
 +    private final InboundMessageHandler.WaitQueue globalWaitQueue;
 +
 +    private final InboundCounters urgentCounters = new InboundCounters();
 +    private final InboundCounters smallCounters  = new InboundCounters();
 +    private final InboundCounters largeCounters  = new InboundCounters();
 +    private final InboundCounters legacyCounters = new InboundCounters();
 +
 +    private final InboundMessageCallbacks urgentCallbacks;
 +    private final InboundMessageCallbacks smallCallbacks;
 +    private final InboundMessageCallbacks largeCallbacks;
 +    private final InboundMessageCallbacks legacyCallbacks;
 +
 +    private final InternodeInboundMetrics metrics;
 +    private final MessageConsumer messageConsumer;
 +
 +    private final HandlerProvider handlerProvider;
 +    private final Collection<InboundMessageHandler> handlers = new 
CopyOnWriteArrayList<>();
 +
 +    static class GlobalResourceLimits
 +    {
 +        final ResourceLimits.Limit reserveCapacity;
 +        final InboundMessageHandler.WaitQueue waitQueue;
 +
 +        GlobalResourceLimits(ResourceLimits.Limit reserveCapacity)
 +        {
 +            this.reserveCapacity = reserveCapacity;
 +            this.waitQueue = 
InboundMessageHandler.WaitQueue.global(reserveCapacity);
 +        }
 +    }
 +
 +    public interface MessageConsumer extends Consumer<Message<?>>
 +    {
 +        void fail(Message.Header header, Throwable failure);
 +    }
 +
 +    public interface GlobalMetricCallbacks
 +    {
 +        LatencyConsumer internodeLatencyRecorder(InetAddressAndPort to);
 +        void recordInternalLatency(Verb verb, long timeElapsed, TimeUnit 
timeUnit);
 +        void recordInternodeDroppedMessage(Verb verb, long timeElapsed, 
TimeUnit timeUnit);
 +    }
 +
 +    public InboundMessageHandlers(InetAddressAndPort self,
 +                                  InetAddressAndPort peer,
 +                                  int queueCapacity,
 +                                  long endpointReserveCapacity,
 +                                  GlobalResourceLimits globalResourceLimits,
 +                                  GlobalMetricCallbacks globalMetricCallbacks,
 +                                  MessageConsumer messageConsumer)
 +    {
 +        this(self, peer, queueCapacity, endpointReserveCapacity, 
globalResourceLimits, globalMetricCallbacks, messageConsumer, 
InboundMessageHandler::new);
 +    }
 +
 +    public InboundMessageHandlers(InetAddressAndPort self,
 +                                  InetAddressAndPort peer,
 +                                  int queueCapacity,
 +                                  long endpointReserveCapacity,
 +                                  GlobalResourceLimits globalResourceLimits,
 +                                  GlobalMetricCallbacks globalMetricCallbacks,
 +                                  MessageConsumer messageConsumer,
 +                                  HandlerProvider handlerProvider)
 +    {
 +        this.self = self;
 +        this.peer = peer;
 +
 +        this.queueCapacity = queueCapacity;
 +        this.endpointReserveCapacity = new 
ResourceLimits.Concurrent(endpointReserveCapacity);
 +        this.globalReserveCapacity = globalResourceLimits.reserveCapacity;
 +        this.endpointWaitQueue = 
InboundMessageHandler.WaitQueue.endpoint(this.endpointReserveCapacity);
 +        this.globalWaitQueue = globalResourceLimits.waitQueue;
 +        this.messageConsumer = messageConsumer;
 +
 +        this.handlerProvider = handlerProvider;
 +
 +        urgentCallbacks = makeMessageCallbacks(peer, urgentCounters, 
globalMetricCallbacks, messageConsumer);
 +        smallCallbacks  = makeMessageCallbacks(peer, smallCounters,  
globalMetricCallbacks, messageConsumer);
 +        largeCallbacks  = makeMessageCallbacks(peer, largeCounters,  
globalMetricCallbacks, messageConsumer);
 +        legacyCallbacks = makeMessageCallbacks(peer, legacyCounters, 
globalMetricCallbacks, messageConsumer);
 +
 +        metrics = new InternodeInboundMetrics(peer, this);
 +    }
 +
 +    InboundMessageHandler createHandler(FrameDecoder frameDecoder, 
ConnectionType type, Channel channel, int version)
 +    {
 +        InboundMessageHandler handler =
 +            handlerProvider.provide(frameDecoder,
 +
 +                                    type,
 +                                    channel,
 +                                    self,
 +                                    peer,
 +                                    version,
 +                                    
OutboundConnections.LARGE_MESSAGE_THRESHOLD,
 +
 +                                    queueCapacity,
 +                                    endpointReserveCapacity,
 +                                    globalReserveCapacity,
 +                                    endpointWaitQueue,
 +                                    globalWaitQueue,
 +
 +                                    this::onHandlerClosed,
 +                                    callbacksFor(type),
 +                                    messageConsumer);
 +        handlers.add(handler);
 +        return handler;
 +    }
 +
 +    void releaseMetrics()
 +    {
 +        metrics.release();
 +    }
 +
 +    private void onHandlerClosed(AbstractMessageHandler handler)
 +    {
 +        assert handler instanceof InboundMessageHandler;
 +        handlers.remove(handler);
 +        absorbCounters((InboundMessageHandler)handler);
 +    }
 +
++    @VisibleForTesting
++    public int count()
++    {
++        return handlers.size();
++    }
++
 +    /*
 +     * Message callbacks
 +     */
 +
 +    private InboundMessageCallbacks callbacksFor(ConnectionType type)
 +    {
 +        switch (type)
 +        {
 +            case URGENT_MESSAGES: return urgentCallbacks;
 +            case  SMALL_MESSAGES: return smallCallbacks;
 +            case  LARGE_MESSAGES: return largeCallbacks;
 +            case LEGACY_MESSAGES: return legacyCallbacks;
 +        }
 +
 +        throw new IllegalArgumentException();
 +    }
 +
 +    private static InboundMessageCallbacks 
makeMessageCallbacks(InetAddressAndPort peer, InboundCounters counters, 
GlobalMetricCallbacks globalMetrics, MessageConsumer messageConsumer)
 +    {
 +        LatencyConsumer internodeLatency = 
globalMetrics.internodeLatencyRecorder(peer);
 +
 +        return new InboundMessageCallbacks()
 +        {
 +            @Override
 +            public void onHeaderArrived(int messageSize, Header header, long 
timeElapsed, TimeUnit unit)
 +            {
 +                // do not log latency if we are within error bars of zero
 +                if (timeElapsed > unit.convert(approxTime.error(), 
NANOSECONDS))
 +                    internodeLatency.accept(timeElapsed, unit);
 +            }
 +
 +            @Override
 +            public void onArrived(int messageSize, Header header, long 
timeElapsed, TimeUnit unit)
 +            {
 +            }
 +
 +            @Override
 +            public void onArrivedExpired(int messageSize, Header header, 
boolean wasCorrupt, long timeElapsed, TimeUnit unit)
 +            {
 +                counters.addExpired(messageSize);
 +
 +                globalMetrics.recordInternodeDroppedMessage(header.verb, 
timeElapsed, unit);
 +            }
 +
 +            @Override
 +            public void onArrivedCorrupt(int messageSize, Header header, long 
timeElapsed, TimeUnit unit)
 +            {
 +                counters.addError(messageSize);
 +
 +                messageConsumer.fail(header, new Crc.InvalidCrc(0, 0)); // 
could use one of the original exceptions?
 +            }
 +
 +            @Override
 +            public void onClosedBeforeArrival(int messageSize, Header header, 
int bytesReceived, boolean wasCorrupt, boolean wasExpired)
 +            {
 +                counters.addError(messageSize);
 +
 +                messageConsumer.fail(header, new 
InvalidSerializedSizeException(header.verb, messageSize, bytesReceived));
 +            }
 +
 +            @Override
 +            public void onExpired(int messageSize, Header header, long 
timeElapsed, TimeUnit unit)
 +            {
 +                counters.addExpired(messageSize);
 +
 +                globalMetrics.recordInternodeDroppedMessage(header.verb, 
timeElapsed, unit);
 +            }
 +
 +            @Override
 +            public void onFailedDeserialize(int messageSize, Header header, 
Throwable t)
 +            {
 +                counters.addError(messageSize);
 +
 +                /*
 +                 * If an exception is caught during deser, return a failure 
response immediately
 +                 * instead of waiting for the callback on the other end to 
expire.
 +                 */
 +                messageConsumer.fail(header, t);
 +            }
 +
 +            @Override
 +            public void onDispatched(int messageSize, Header header)
 +            {
 +                counters.addPending(messageSize);
 +            }
 +
 +            @Override
 +            public void onExecuting(int messageSize, Header header, long 
timeElapsed, TimeUnit unit)
 +            {
 +                globalMetrics.recordInternalLatency(header.verb, timeElapsed, 
unit);
 +            }
 +
 +            @Override
 +            public void onExecuted(int messageSize, Header header, long 
timeElapsed, TimeUnit unit)
 +            {
 +                counters.removePending(messageSize);
 +            }
 +
 +            @Override
 +            public void onProcessed(int messageSize, Header header)
 +            {
 +                counters.addProcessed(messageSize);
 +            }
 +        };
 +    }
 +
 +    /*
 +     * Aggregated counters
 +     */
 +
 +    InboundCounters countersFor(ConnectionType type)
 +    {
 +        switch (type)
 +        {
 +            case URGENT_MESSAGES: return urgentCounters;
 +            case  SMALL_MESSAGES: return smallCounters;
 +            case  LARGE_MESSAGES: return largeCounters;
 +            case LEGACY_MESSAGES: return legacyCounters;
 +        }
 +
 +        throw new IllegalArgumentException();
 +    }
 +
 +    public long receivedCount()
 +    {
 +        return sumHandlers(h -> h.receivedCount) + closedReceivedCount;
 +    }
 +
 +    public long receivedBytes()
 +    {
 +        return sumHandlers(h -> h.receivedBytes) + closedReceivedBytes;
 +    }
 +
 +    public long throttledCount()
 +    {
 +        return sumHandlers(h -> h.throttledCount) + closedThrottledCount;
 +    }
 +
 +    public long throttledNanos()
 +    {
 +        return sumHandlers(h -> h.throttledNanos) + closedThrottledNanos;
 +    }
 +
 +    public long usingCapacity()
 +    {
 +        return sumHandlers(h -> h.queueSize);
 +    }
 +
 +    public long usingEndpointReserveCapacity()
 +    {
 +        return endpointReserveCapacity.using();
 +    }
 +
 +    public long corruptFramesRecovered()
 +    {
 +        return sumHandlers(h -> h.corruptFramesRecovered) + 
closedCorruptFramesRecovered;
 +    }
 +
 +    public long corruptFramesUnrecovered()
 +    {
 +        return sumHandlers(h -> h.corruptFramesUnrecovered) + 
closedCorruptFramesUnrecovered;
 +    }
 +
 +    public long errorCount()
 +    {
 +        return sumCounters(InboundCounters::errorCount);
 +    }
 +
 +    public long errorBytes()
 +    {
 +        return sumCounters(InboundCounters::errorBytes);
 +    }
 +
 +    public long expiredCount()
 +    {
 +        return sumCounters(InboundCounters::expiredCount);
 +    }
 +
 +    public long expiredBytes()
 +    {
 +        return sumCounters(InboundCounters::expiredBytes);
 +    }
 +
 +    public long processedCount()
 +    {
 +        return sumCounters(InboundCounters::processedCount);
 +    }
 +
 +    public long processedBytes()
 +    {
 +        return sumCounters(InboundCounters::processedBytes);
 +    }
 +
 +    public long scheduledCount()
 +    {
 +        return sumCounters(InboundCounters::scheduledCount);
 +    }
 +
 +    public long scheduledBytes()
 +    {
 +        return sumCounters(InboundCounters::scheduledBytes);
 +    }
 +
 +    /*
 +     * 'Archived' counter values, combined for all connections that have been 
closed.
 +     */
 +
 +    private volatile long closedReceivedCount, closedReceivedBytes;
 +
 +    private static final AtomicLongFieldUpdater<InboundMessageHandlers> 
closedReceivedCountUpdater =
 +        AtomicLongFieldUpdater.newUpdater(InboundMessageHandlers.class, 
"closedReceivedCount");
 +    private static final AtomicLongFieldUpdater<InboundMessageHandlers> 
closedReceivedBytesUpdater =
 +        AtomicLongFieldUpdater.newUpdater(InboundMessageHandlers.class, 
"closedReceivedBytes");
 +
 +    private volatile long closedThrottledCount, closedThrottledNanos;
 +
 +    private static final AtomicLongFieldUpdater<InboundMessageHandlers> 
closedThrottledCountUpdater =
 +        AtomicLongFieldUpdater.newUpdater(InboundMessageHandlers.class, 
"closedThrottledCount");
 +    private static final AtomicLongFieldUpdater<InboundMessageHandlers> 
closedThrottledNanosUpdater =
 +        AtomicLongFieldUpdater.newUpdater(InboundMessageHandlers.class, 
"closedThrottledNanos");
 +
 +    private volatile long closedCorruptFramesRecovered, 
closedCorruptFramesUnrecovered;
 +
 +    private static final AtomicLongFieldUpdater<InboundMessageHandlers> 
closedCorruptFramesRecoveredUpdater =
 +        AtomicLongFieldUpdater.newUpdater(InboundMessageHandlers.class, 
"closedCorruptFramesRecovered");
 +    private static final AtomicLongFieldUpdater<InboundMessageHandlers> 
closedCorruptFramesUnrecoveredUpdater =
 +        AtomicLongFieldUpdater.newUpdater(InboundMessageHandlers.class, 
"closedCorruptFramesUnrecovered");
 +
 +    private void absorbCounters(InboundMessageHandler handler)
 +    {
 +        closedReceivedCountUpdater.addAndGet(this, handler.receivedCount);
 +        closedReceivedBytesUpdater.addAndGet(this, handler.receivedBytes);
 +
 +        closedThrottledCountUpdater.addAndGet(this, handler.throttledCount);
 +        closedThrottledNanosUpdater.addAndGet(this, handler.throttledNanos);
 +
 +        closedCorruptFramesRecoveredUpdater.addAndGet(this, 
handler.corruptFramesRecovered);
 +        closedCorruptFramesUnrecoveredUpdater.addAndGet(this, 
handler.corruptFramesUnrecovered);
 +    }
 +
 +    private long sumHandlers(ToLongFunction<InboundMessageHandler> counter)
 +    {
 +        long sum = 0L;
 +        for (InboundMessageHandler h : handlers)
 +            sum += counter.applyAsLong(h);
 +        return sum;
 +    }
 +
 +    private long sumCounters(ToLongFunction<InboundCounters> mapping)
 +    {
 +        return mapping.applyAsLong(urgentCounters)
 +             + mapping.applyAsLong(smallCounters)
 +             + mapping.applyAsLong(largeCounters)
 +             + mapping.applyAsLong(legacyCounters);
 +    }
 +
 +    interface HandlerProvider
 +    {
 +        InboundMessageHandler provide(FrameDecoder decoder,
 +
 +                                      ConnectionType type,
 +                                      Channel channel,
 +                                      InetAddressAndPort self,
 +                                      InetAddressAndPort peer,
 +                                      int version,
 +                                      int largeMessageThreshold,
 +
 +                                      int queueCapacity,
 +                                      ResourceLimits.Limit 
endpointReserveCapacity,
 +                                      ResourceLimits.Limit 
globalReserveCapacity,
 +                                      InboundMessageHandler.WaitQueue 
endpointWaitQueue,
 +                                      InboundMessageHandler.WaitQueue 
globalWaitQueue,
 +
 +                                      InboundMessageHandler.OnHandlerClosed 
onClosed,
 +                                      InboundMessageCallbacks callbacks,
 +                                      Consumer<Message<?>> consumer);
 +    }
 +}
diff --cc 
test/distributed/org/apache/cassandra/distributed/test/InternodeEncryptionEnforcementTest.java
index 0000000,5c0e3b3..86a7c99
mode 000000,100644..100644
--- 
a/test/distributed/org/apache/cassandra/distributed/test/InternodeEncryptionEnforcementTest.java
+++ 
b/test/distributed/org/apache/cassandra/distributed/test/InternodeEncryptionEnforcementTest.java
@@@ -1,0 -1,137 +1,134 @@@
+ /*
+  * Licensed to the Apache Software Foundation (ASF) under one
+  * or more contributor license agreements.  See the NOTICE file
+  * distributed with this work for additional information
+  * regarding copyright ownership.  The ASF licenses this file
+  * to you under the Apache License, Version 2.0 (the
+  * "License"); you may not use this file except in compliance
+  * with the License.  You may obtain a copy of the License at
+  *
+  *     http://www.apache.org/licenses/LICENSE-2.0
+  *
+  * Unless required by applicable law or agreed to in writing, software
+  * distributed under the License is distributed on an "AS IS" BASIS,
+  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+  * See the License for the specific language governing permissions and
+  * limitations under the License.
+  */
+ package org.apache.cassandra.distributed.test;
+ 
+ import java.util.HashMap;
 -import java.util.List;
+ 
+ import com.google.common.collect.ImmutableMap;
+ import org.junit.Test;
+ 
+ import org.apache.cassandra.distributed.Cluster;
+ import org.apache.cassandra.distributed.api.Feature;
+ import 
org.apache.cassandra.distributed.api.IIsolatedExecutor.SerializableRunnable;
+ import org.apache.cassandra.distributed.shared.NetworkTopology;
++import org.apache.cassandra.net.InboundMessageHandlers;
+ import org.apache.cassandra.net.MessagingService;
++import org.apache.cassandra.net.OutboundConnections;
+ 
+ import static com.google.common.collect.Iterables.getOnlyElement;
+ import static org.junit.Assert.assertEquals;
++import static org.junit.Assert.assertFalse;
+ import static org.junit.Assert.assertTrue;
+ 
++/**
++ * Distributed tests for inbound enforcement of {@code internode_encryption: dc} on a two-node
++ * cluster whose nodes sit in different datacenters (dc1 and dc2).
++ *
++ * <p>When only node 1 is configured for encryption, no internode connections may be established
++ * in either direction. When both nodes are configured, each node must have an inbound handler and
++ * at least one connected outbound channel.
++ */
+ public final class InternodeEncryptionEnforcementTest extends TestBaseImpl
+ {
+     @Test
+     public void testConnectionsAreRejectedWithInvalidConfig() throws Throwable
+     {
+         Cluster.Builder builder = builder()
+             .withNodes(2)
+             .withConfig(c ->
+             {
+                 c.with(Feature.NETWORK);
+                 c.with(Feature.NATIVE_PROTOCOL);
+ 
+                 if (c.num() == 1)
+                 {
+                     HashMap<String, Object> encryption = new HashMap<>();
+                     encryption.put("keystore", "test/conf/cassandra_ssl_test.keystore");
+                     encryption.put("keystore_password", "cassandra");
+                     encryption.put("truststore", "test/conf/cassandra_ssl_test.truststore");
+                     encryption.put("truststore_password", "cassandra");
+                     encryption.put("internode_encryption", "dc");
+                     c.set("server_encryption_options", encryption);
+                 }
+             })
+             .withNodeIdTopology(ImmutableMap.of(1, NetworkTopology.dcAndRack("dc1", "r1a"),
+                                                 2, NetworkTopology.dcAndRack("dc2", "r2a")));
+ 
+         try (Cluster cluster = builder.start())
+         {
+             /*
+              * instance (1) won't connect to (2), since (2) won't have a TLS listener;
+              * instance (2) won't connect to (1), since inbound check will reject
+              * the unencrypted connection attempt;
+              *
+              * without the patch, instance (2) *CAN* connect to (1), without encryption,
+              * despite being in a different dc.
+              */
+ 
+             cluster.get(1).runOnInstance(() ->
+             {
 -                List<MessagingService.SocketThread> threads = MessagingService.instance().getSocketThreads();
 -                assertEquals(2, threads.size());
++                InboundMessageHandlers inbound = getOnlyElement(MessagingService.instance().messageHandlers.values());
++                assertEquals(0, inbound.count());
+ 
 -                for (MessagingService.SocketThread thread : threads)
 -                {
 -                    assertEquals(0, thread.connections.size());
 -                }
++                OutboundConnections outbound = getOnlyElement(MessagingService.instance().channelManagers.values());
++                assertFalse(outbound.small.isConnected() || outbound.large.isConnected() || outbound.urgent.isConnected());
+             });
+ 
+             cluster.get(2).runOnInstance(() ->
+             {
 -                List<MessagingService.SocketThread> threads = MessagingService.instance().getSocketThreads();
 -                assertEquals(1, threads.size());
 -                assertTrue(getOnlyElement(threads).connections.isEmpty());
++                assertTrue(MessagingService.instance().messageHandlers.isEmpty());
++
++                OutboundConnections outbound = getOnlyElement(MessagingService.instance().channelManagers.values());
++                assertFalse(outbound.small.isConnected() || outbound.large.isConnected() || outbound.urgent.isConnected());
+             });
+         }
+     }
+ 
+     @Test
+     public void testConnectionsAreAcceptedWithValidConfig() throws Throwable
+     {
+         Cluster.Builder builder = builder()
+             .withNodes(2)
+             .withConfig(c ->
+             {
+                 c.with(Feature.NETWORK);
+                 c.with(Feature.NATIVE_PROTOCOL);
+ 
 -                HashMap<String, Object> encryption = new HashMap<>();
 -                encryption.put("keystore", "test/conf/cassandra_ssl_test.keystore");
++                HashMap<String, Object> encryption = new HashMap<>();
++                encryption.put("keystore", "test/conf/cassandra_ssl_test.keystore");
+                 encryption.put("keystore_password", "cassandra");
+                 encryption.put("truststore", "test/conf/cassandra_ssl_test.truststore");
+                 encryption.put("truststore_password", "cassandra");
+                 encryption.put("internode_encryption", "dc");
+                 c.set("server_encryption_options", encryption);
+             })
+             .withNodeIdTopology(ImmutableMap.of(1, NetworkTopology.dcAndRack("dc1", "r1a"),
+                                                 2, NetworkTopology.dcAndRack("dc2", "r2a")));
+ 
+         try (Cluster cluster = builder.start())
+         {
+             /*
+              * instance (1) should connect to instance (2) without any issues;
+              * instance (2) should connect to instance (1) without any issues.
+              */
+ 
+             SerializableRunnable runnable = () ->
+             {
 -                List<MessagingService.SocketThread> threads = MessagingService.instance().getSocketThreads();
 -                assertEquals(2, threads.size());
 -
 -                MessagingService.SocketThread sslThread = threads.get(0);
 -                assertEquals(1, sslThread.connections.size());
++                InboundMessageHandlers inbound = getOnlyElement(MessagingService.instance().messageHandlers.values());
++                assertTrue(inbound.count() > 0);
+ 
 -                MessagingService.SocketThread plainThread = threads.get(1);
 -                assertEquals(0, plainThread.connections.size());
++                OutboundConnections outbound = getOnlyElement(MessagingService.instance().channelManagers.values());
++                assertTrue(outbound.small.isConnected() || outbound.large.isConnected() || outbound.urgent.isConnected());
+             };
+ 
+             cluster.get(1).runOnInstance(runnable);
+             cluster.get(2).runOnInstance(runnable);
+         }
+     }
+ }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to