This is an automated email from the ASF dual-hosted git repository.

samt pushed a commit to branch cassandra-4.1
in repository https://gitbox.apache.org/repos/asf/cassandra.git


The following commit(s) were added to refs/heads/cassandra-4.1 by this push:
     new 6f90e962f5 Enforce CQL message size limit on multiframe messages
6f90e962f5 is described below

commit 6f90e962f54c4b1a90ad6c3dc0bb6a224843abf0
Author: Dmitry Konstantinov <netud...@gmail.com>
AuthorDate: Mon Nov 4 21:17:24 2024 +0000

    Enforce CQL message size limit on multiframe messages
    
    Patch by Dmitry Konstantinov; reviewed by Sam Tunnicliffe,
    Caleb Rackliffe for CASSANDRA-20052
---
 CHANGES.txt                                        |   1 +
 src/java/org/apache/cassandra/config/Config.java   |   1 +
 .../cassandra/config/DatabaseDescriptor.java       |  52 +++++++
 .../exceptions/OversizedCQLMessageException.java   |  27 ++++
 .../cassandra/net/AbstractMessageHandler.java      |   4 +-
 .../cassandra/transport/CQLMessageHandler.java     | 110 ++++++++++++--
 .../cassandra/transport/ExceptionHandlers.java     |   5 +
 .../transport/InitialConnectionHandler.java        |   3 +-
 .../cassandra/transport/PipelineConfigurator.java  |   2 +
 .../apache/cassandra/transport/SimpleClient.java   |   1 +
 .../transport/AuthMessageSizeLimitTest.java        | 104 +++++++++++++
 .../transport/ClientResourceLimitsTest.java        | 106 ++++---------
 .../cassandra/transport/MessageSizeLimitTest.java  | 124 ++++++++++++++++
 .../transport/NativeProtocolLimitsTestBase.java    | 165 +++++++++++++++++++++
 .../cassandra/transport/RateLimitingTest.java      |  69 +++------
 15 files changed, 632 insertions(+), 142 deletions(-)

diff --git a/CHANGES.txt b/CHANGES.txt
index 3e5597558d..45ef4fe379 100644
--- a/CHANGES.txt
+++ b/CHANGES.txt
@@ -1,4 +1,5 @@
 4.1.8
+ * Enforce CQL message size limit on multiframe messages (CASSANDRA-20052)
  * Add nodetool checktokenmetadata command that checks TokenMetadata is insync 
with Gossip endpointState (CASSANDRA-18758)
  * Backport Java 11 support for Simulator (CASSANDRA-17178/CASSANDRA-19935)
  * Equality check for Paxos.Electorate should not depend on collection types 
(CASSANDRA-19935)
diff --git a/src/java/org/apache/cassandra/config/Config.java 
b/src/java/org/apache/cassandra/config/Config.java
index d8d1f7e617..b64ff079de 100644
--- a/src/java/org/apache/cassandra/config/Config.java
+++ b/src/java/org/apache/cassandra/config/Config.java
@@ -269,6 +269,7 @@ public class Config
     public int native_transport_max_threads = 128;
     @Replaces(oldName = "native_transport_max_frame_size_in_mb", converter = 
Converters.MEBIBYTES_DATA_STORAGE_INT, deprecated = true)
     public DataStorageSpec.IntMebibytesBound native_transport_max_frame_size = 
new DataStorageSpec.IntMebibytesBound("16MiB");
+    public volatile DataStorageSpec.LongBytesBound 
native_transport_max_message_size = null;
     /** do bcrypt hashing in a limited pool to prevent cpu load spikes; 0 
means that all requests will go to default request executor**/
     public int native_transport_max_auth_threads = 0;
     public volatile long native_transport_max_concurrent_connections = -1L;
diff --git a/src/java/org/apache/cassandra/config/DatabaseDescriptor.java 
b/src/java/org/apache/cassandra/config/DatabaseDescriptor.java
index b85a680cfa..3ff03ce801 100644
--- a/src/java/org/apache/cassandra/config/DatabaseDescriptor.java
+++ b/src/java/org/apache/cassandra/config/DatabaseDescriptor.java
@@ -153,6 +153,9 @@ public class DatabaseDescriptor
     private static long counterCacheSizeInMiB;
     private static long indexSummaryCapacityInMiB;
 
+    private static volatile long nativeTransportMaxMessageSizeInBytes;
+    private static volatile boolean 
nativeTransportMaxMessageSizeConfiguredExplicitly;
+
     private static String localDC;
     private static Comparator<Replica> localComparator;
     private static EncryptionContext encryptionContext;
@@ -822,6 +825,23 @@ public class DatabaseDescriptor
         else if (conf.commitlog_segment_size.toKibibytes() < 2 * 
conf.max_mutation_size.toKibibytes())
             throw new ConfigurationException("commitlog_segment_size must be 
at least twice the size of max_mutation_size / 1024", false);
 
+        if (conf.native_transport_max_message_size == null)
+        {
+            conf.native_transport_max_message_size = new 
DataStorageSpec.LongBytesBound(calculateDefaultNativeTransportMaxMessageSizeInBytes());
+        }
+        else
+        {
+            nativeTransportMaxMessageSizeConfiguredExplicitly = true;
+            long maxCqlMessageSize = 
conf.native_transport_max_message_size.toBytes();
+            if (maxCqlMessageSize > 
conf.native_transport_max_request_data_in_flight.toBytes())
+                throw new 
ConfigurationException("native_transport_max_message_size must not exceed 
native_transport_max_request_data_in_flight", false);
+
+            if (maxCqlMessageSize > 
conf.native_transport_max_request_data_in_flight_per_ip.toBytes())
+                throw new 
ConfigurationException("native_transport_max_message_size must not exceed 
native_transport_max_request_data_in_flight_per_ip", false);
+
+        }
+        nativeTransportMaxMessageSizeInBytes = 
conf.native_transport_max_message_size.toBytes();
+
         // native transport encryption options
         if (conf.client_encryption_options != null)
         {
@@ -2748,6 +2768,30 @@ public class DatabaseDescriptor
         return 
conf.native_transport_max_request_data_in_flight_per_ip.toBytes();
     }
 
+    public static long getNativeTransportMaxMessageSizeInBytes()
+    {
+        // the value of native_transport_max_message_size in bytes is cached
+        // to avoid conversion overhead when parsing each incoming CQL message
+        return nativeTransportMaxMessageSizeInBytes;
+    }
+
+    @VisibleForTesting
+    public static void setNativeTransportMaxMessageSizeInBytes(long 
maxMessageSizeInBytes)
+    {
+        conf.native_transport_max_message_size = new 
DataStorageSpec.LongBytesBound(maxMessageSizeInBytes);
+        nativeTransportMaxMessageSizeInBytes = 
conf.native_transport_max_message_size.toBytes();
+    }
+
+    private static long calculateDefaultNativeTransportMaxMessageSizeInBytes()
+    {
+        return Math.min(conf.max_mutation_size.toBytes(),
+                   Math.min(
+                   conf.native_transport_max_request_data_in_flight.toBytes(),
+                   
conf.native_transport_max_request_data_in_flight_per_ip.toBytes()
+                   )
+        );
+    }
+
     public static Config.PaxosVariant getPaxosVariant()
     {
         return conf.paxos_variant;
@@ -2874,6 +2918,10 @@ public class DatabaseDescriptor
             maxRequestDataInFlightInBytes = Runtime.getRuntime().maxMemory() / 
40;
 
         conf.native_transport_max_request_data_in_flight_per_ip = new 
DataStorageSpec.LongBytesBound(maxRequestDataInFlightInBytes);
+        long newNativeTransportMaxMessageSizeInBytes = 
nativeTransportMaxMessageSizeConfiguredExplicitly
+                                                       ? 
Math.min(maxRequestDataInFlightInBytes, 
getNativeTransportMaxMessageSizeInBytes())
+                                                       : 
calculateDefaultNativeTransportMaxMessageSizeInBytes();
+        
setNativeTransportMaxMessageSizeInBytes(newNativeTransportMaxMessageSizeInBytes);
     }
 
     public static long getNativeTransportMaxRequestDataInFlightInBytes()
@@ -2887,6 +2935,10 @@ public class DatabaseDescriptor
             maxRequestDataInFlightInBytes = Runtime.getRuntime().maxMemory() / 
10;
 
         conf.native_transport_max_request_data_in_flight = new 
DataStorageSpec.LongBytesBound(maxRequestDataInFlightInBytes);
+        long newNativeTransportMaxMessageSizeInBytes = 
nativeTransportMaxMessageSizeConfiguredExplicitly
+                                                       ? 
Math.min(maxRequestDataInFlightInBytes, 
getNativeTransportMaxMessageSizeInBytes())
+                                                       : 
calculateDefaultNativeTransportMaxMessageSizeInBytes();
+        
setNativeTransportMaxMessageSizeInBytes(newNativeTransportMaxMessageSizeInBytes);
     }
 
     public static int getNativeTransportMaxRequestsPerSecond()
diff --git 
a/src/java/org/apache/cassandra/exceptions/OversizedCQLMessageException.java 
b/src/java/org/apache/cassandra/exceptions/OversizedCQLMessageException.java
new file mode 100644
index 0000000000..60f9c48835
--- /dev/null
+++ b/src/java/org/apache/cassandra/exceptions/OversizedCQLMessageException.java
@@ -0,0 +1,27 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.cassandra.exceptions;
+
+public class OversizedCQLMessageException extends InvalidRequestException
+{
+    public OversizedCQLMessageException(String message)
+    {
+        super(message);
+    }
+}
diff --git a/src/java/org/apache/cassandra/net/AbstractMessageHandler.java 
b/src/java/org/apache/cassandra/net/AbstractMessageHandler.java
index e2cf68d6d1..5b5b8b7f1a 100644
--- a/src/java/org/apache/cassandra/net/AbstractMessageHandler.java
+++ b/src/java/org/apache/cassandra/net/AbstractMessageHandler.java
@@ -562,7 +562,7 @@ public abstract class AbstractMessageHandler extends 
ChannelInboundHandlerAdapte
             return size == received;
         }
 
-        private void onIntactFrame(IntactFrame frame)
+        protected void onIntactFrame(IntactFrame frame)
         {
             boolean expires = approxTime.isAfter(expiresAtNanos);
             if (!isExpired && !isCorrupt)
@@ -578,7 +578,7 @@ public abstract class AbstractMessageHandler extends 
ChannelInboundHandlerAdapte
             isExpired |= expires;
         }
 
-        private void onCorruptFrame()
+        protected void onCorruptFrame()
         {
             if (!isExpired && !isCorrupt)
                 releaseBuffersAndCapacity(); // release resources once we 
transition from normal state to corrupt
diff --git a/src/java/org/apache/cassandra/transport/CQLMessageHandler.java 
b/src/java/org/apache/cassandra/transport/CQLMessageHandler.java
index 65c0282908..792f6bf7b6 100644
--- a/src/java/org/apache/cassandra/transport/CQLMessageHandler.java
+++ b/src/java/org/apache/cassandra/transport/CQLMessageHandler.java
@@ -31,6 +31,7 @@ import io.netty.buffer.Unpooled;
 import io.netty.channel.Channel;
 import org.apache.cassandra.config.DatabaseDescriptor;
 import org.apache.cassandra.exceptions.OverloadedException;
+import org.apache.cassandra.exceptions.OversizedCQLMessageException;
 import org.apache.cassandra.metrics.ClientMessageSizeMetrics;
 import org.apache.cassandra.metrics.ClientMetrics;
 import org.apache.cassandra.net.AbstractMessageHandler;
@@ -81,6 +82,10 @@ public class CQLMessageHandler<M extends Message> extends 
AbstractMessageHandler
     public static final int LARGE_MESSAGE_THRESHOLD = 
FrameEncoder.Payload.MAX_SIZE - 1;
     public static final TimeUnit RATE_LIMITER_DELAY_UNIT = 
TimeUnit.NANOSECONDS;
 
+    // Prefix of the fatal error raised when a multi-frame CQL message arrives before the connection is READY.
+    static final String MULTI_FRAME_AUTH_ERROR_MESSAGE_PREFIX = "The connection is not yet in a valid state " +
+                                                                "to process multi frame CQL Messages, usually this " +
+                                                                "means that authentication is still pending. ";
+
     private final QueueBackpressure queueBackpressure;
     private final Envelope.Decoder envelopeDecoder;
     private final Message.Decoder<M> messageDecoder;
@@ -94,6 +99,8 @@ public class CQLMessageHandler<M extends Message> extends 
AbstractMessageHandler
     long channelPayloadBytesInFlight;
     private int consecutiveMessageErrors = 0;
 
+    private final ServerConnection serverConnection;
+
     interface MessageConsumer<M extends Message>
     {
         void dispatch(Channel channel, M message, 
Dispatcher.FlushItemConverter toFlushItem, Overload backpressure);
@@ -106,6 +113,7 @@ public class CQLMessageHandler<M extends Message> extends 
AbstractMessageHandler
     }
 
     CQLMessageHandler(Channel channel,
+                      ServerConnection serverConnection,
                       ProtocolVersion version,
                       FrameDecoder decoder,
                       Envelope.Decoder envelopeDecoder,
@@ -128,6 +136,7 @@ public class CQLMessageHandler<M extends Message> extends 
AbstractMessageHandler
               resources.endpointWaitQueue(),
               resources.globalWaitQueue(),
               onClosed);
+        this.serverConnection = serverConnection;
         this.envelopeDecoder    = envelopeDecoder;
         this.messageDecoder     = messageDecoder;
         this.payloadAllocator   = payloadAllocator;
@@ -518,10 +527,28 @@ public class CQLMessageHandler<M extends Message> extends 
AbstractMessageHandler
             // max CQL message size defaults to 256mb, so should be safe to 
downcast
             int messageSize = Ints.checkedCast(header.bodySizeInBytes);
             receivedBytes += buf.remaining();
+
+            if (serverConnection != null && serverConnection.stage() != 
ConnectionStage.READY)
+            {
+                // Disallow any multiframe messages before the connection 
reaches the READY state.
+                // This guards against being swamped with oversize messages 
from unauthenticated
+                // clients. In this case, we raise a fatal error and close the 
connection so it does
+                // not make sense to continue processing subsequent frames
+                handleError(ProtocolException.toFatalException(new 
OversizedAuthMessageException(
+                            MULTI_FRAME_AUTH_ERROR_MESSAGE_PREFIX +
+                            "type = " + header.type + ", size = " + 
header.bodySizeInBytes)));
+                ClientMetrics.instance.markRequestDiscarded();
+                return false;
+            }
             
             LargeMessage largeMessage = new LargeMessage(header);
-
-            if (throwOnOverload)
+            if (messageSize > 
DatabaseDescriptor.getNativeTransportMaxMessageSizeInBytes())
+            {
+                ClientMetrics.instance.markRequestDiscarded();
+                // Mark as too big so that we discard the message after consuming 
any subsequent frames
+                largeMessage.markTooBig();
+            }
+            else if (throwOnOverload)
             {
                 if (!acquireCapacity(header, endpointReserve, globalReserve))
                 {
@@ -567,9 +594,9 @@ public class CQLMessageHandler<M extends Message> extends 
AbstractMessageHandler
                     }
                 }
             }
-            else
+            else // throwOnOverload = false
             {
-                if (acquireCapacity(header, endpointReserve, globalReserve))
+                if (acquireCapacityAndQueueOnFailure(header, endpointReserve, 
globalReserve))
                 {
                     long delay = -1;
                     Overload backpressure = Overload.NONE;
@@ -605,7 +632,14 @@ public class CQLMessageHandler<M extends Message> extends 
AbstractMessageHandler
                 }
                 else
                 {
-                    noSpamLogger.error("Could not aquire capacity while 
processing native protocol message");
+                    // we checked previously that messageSize <= 
native_transport_max_message_size
+                    // and native_transport_max_message_size <= 
native_transport_max_request_data_in_flight
+                    // and native_transport_max_message_size <= 
native_transport_max_request_data_in_flight_per_ip
+                    // so starvation is not possible in the following case:
+                    // a connection is blocked forever because somebody tries to 
send a single message larger than the total rate limiting capacity.
+                    // Once other messages in the same or other CQL 
connections are processed and capacity is returned to the limits,
+                    // there will be enough capacity to acquire for the current 
large message.
+                    return false;
                 }
             }
 
@@ -712,6 +746,7 @@ public class CQLMessageHandler<M extends Message> extends 
AbstractMessageHandler
 
         private Overload overload = Overload.NONE;
         private Overload backpressure = Overload.NONE;
+        private boolean tooBig = false;
 
         private LargeMessage(Envelope.Header header)
         {
@@ -747,18 +782,75 @@ public class CQLMessageHandler<M extends Message> extends 
AbstractMessageHandler
             this.backpressure = backpressure;
         }
 
+        private void markTooBig()
+        {
+            this.tooBig = true;
+        }
+
+        @Override
+        protected void onIntactFrame(IntactFrame frame)
+        {
+            if (tooBig || overload != Overload.NONE)
+                // we do not want to add the frame to buffers (to avoid consuming 
a lot of memory only to throw it away later);
+                // we also do not want to release capacity because we haven't 
acquired it
+                frame.consume();
+            else
+                super.onIntactFrame(frame);
+        }
+
+        @Override
+        protected void onCorruptFrame()
+        {
+            if (!isExpired && !isCorrupt && !tooBig)
+            {
+                releaseBuffers(); // release resources once we transition from 
normal state to corrupt
+                if (overload != Overload.BYTES_IN_FLIGHT)
+                    releaseCapacity(size);
+            }
+            isCorrupt = true;
+            isExpired |= approxTime.isAfter(expiresAtNanos);
+        }
+
+
+        @Override
         protected void onComplete()
         {
-            if (overload != Overload.NONE)
-                
handleErrorAndRelease(buildOverloadedException(endpointReserveCapacity, 
globalReserveCapacity, overload), header);
+            if (tooBig)
+                // we never acquired capacity for a too-big message, so there is 
nothing to release
+                
handleError(buildOversizedCQLMessageException(header.bodySizeInBytes), header);
+            else if (overload != Overload.NONE)
+                if (overload == Overload.BYTES_IN_FLIGHT)
+                    // we failed to acquire capacity, so there is nothing to 
release
+                    
handleError(buildOverloadedException(endpointReserveCapacity, 
globalReserveCapacity, overload), header);
+                else
+                    
handleErrorAndRelease(buildOverloadedException(endpointReserveCapacity, 
globalReserveCapacity, overload), header);
             else if (!isCorrupt)
                 processRequest(assembleFrame(), backpressure);
         }
 
+        @Override
         protected void abort()
         {
-            if (!isCorrupt)
-                releaseBuffersAndCapacity(); // release resources if in normal 
state when abort() is invoked
+            if (!isCorrupt && !tooBig && overload == Overload.NONE)
+                releaseBuffers();
+
+            if (overload == Overload.NONE || overload == 
Overload.BYTES_IN_FLIGHT)
+                releaseCapacity(size);
+        }
+
+        private OversizedCQLMessageException 
buildOversizedCQLMessageException(long messageBodySize)
+        {
+            return new OversizedCQLMessageException("CQL Message of size " + 
messageBodySize
+                                                    + " bytes exceeds allowed 
maximum of "
+                                                    + 
DatabaseDescriptor.getNativeTransportMaxMessageSizeInBytes() + " bytes");
+        }
+    }
+
+    static class OversizedAuthMessageException extends ProtocolException
+    {
+        OversizedAuthMessageException(String message)
+        {
+            super(message);
         }
     }
 }
diff --git a/src/java/org/apache/cassandra/transport/ExceptionHandlers.java 
b/src/java/org/apache/cassandra/transport/ExceptionHandlers.java
index 4f063924ea..4d36fa6cd7 100644
--- a/src/java/org/apache/cassandra/transport/ExceptionHandlers.java
+++ b/src/java/org/apache/cassandra/transport/ExceptionHandlers.java
@@ -38,6 +38,7 @@ import io.netty.channel.ChannelPromise;
 import io.netty.channel.unix.Errors;
 import org.apache.cassandra.exceptions.OverloadedException;
 import org.apache.cassandra.config.DatabaseDescriptor;
+import org.apache.cassandra.exceptions.OversizedCQLMessageException;
 import org.apache.cassandra.metrics.ClientMetrics;
 import org.apache.cassandra.net.FrameEncoder;
 import org.apache.cassandra.transport.messages.ErrorMessage;
@@ -130,6 +131,10 @@ public class ExceptionHandlers
             // Once the threshold for overload is breached, it will very 
likely spam the logs...
             NoSpamLogger.log(logger, NoSpamLogger.Level.INFO, 1, 
TimeUnit.MINUTES, cause.getMessage());
         }
+        else if (Throwables.anyCauseMatches(cause, t -> t instanceof 
OversizedCQLMessageException))
+        {
+            NoSpamLogger.log(logger, NoSpamLogger.Level.INFO, 1, 
TimeUnit.MINUTES, cause.getMessage());
+        }
         else if (Throwables.anyCauseMatches(cause, t -> t instanceof 
Errors.NativeIoException))
         {
             ClientMetrics.instance.markUnknownException();
diff --git 
a/src/java/org/apache/cassandra/transport/InitialConnectionHandler.java 
b/src/java/org/apache/cassandra/transport/InitialConnectionHandler.java
index 576af3e6dc..463d824c4c 100644
--- a/src/java/org/apache/cassandra/transport/InitialConnectionHandler.java
+++ b/src/java/org/apache/cassandra/transport/InitialConnectionHandler.java
@@ -104,6 +104,7 @@ public class InitialConnectionHandler extends 
ByteToMessageDecoder
                         attrConn.set(connection);
                     }
                     assert connection instanceof ServerConnection;
+                    ServerConnection serverConnection = (ServerConnection) 
connection;
 
                     StartupMessage startup = (StartupMessage) 
Message.Decoder.decodeMessage(ctx.channel(), inbound);
                     InetAddress remoteAddress = ((InetSocketAddress) 
ctx.channel().remoteAddress()).getAddress();
@@ -120,7 +121,7 @@ public class InitialConnectionHandler extends 
ByteToMessageDecoder
                             if (future.isSuccess())
                             {
                                 logger.trace("Response to STARTUP sent, 
configuring pipeline for {}", inbound.header.version);
-                                configurator.configureModernPipeline(ctx, 
allocator, inbound.header.version, startup.options);
+                                configurator.configureModernPipeline(ctx, 
serverConnection, allocator, inbound.header.version, startup.options);
                                 
allocator.release(inbound.header.bodySizeInBytes);
                             }
                             else
diff --git a/src/java/org/apache/cassandra/transport/PipelineConfigurator.java 
b/src/java/org/apache/cassandra/transport/PipelineConfigurator.java
index 10ca818719..6ad00471c7 100644
--- a/src/java/org/apache/cassandra/transport/PipelineConfigurator.java
+++ b/src/java/org/apache/cassandra/transport/PipelineConfigurator.java
@@ -270,6 +270,7 @@ public class PipelineConfigurator
     }
 
     public void configureModernPipeline(ChannelHandlerContext ctx,
+                                        ServerConnection serverConnection,
                                         ClientResourceLimits.Allocator 
resourceAllocator,
                                         ProtocolVersion version,
                                         Map<String, String> options)
@@ -311,6 +312,7 @@ public class PipelineConfigurator
         CQLMessageHandler.MessageConsumer<Message.Request> messageConsumer = 
messageConsumer();
         CQLMessageHandler<Message.Request> processor =
             new CQLMessageHandler<>(ctx.channel(),
+                                    serverConnection,
                                     version,
                                     frameDecoder,
                                     envelopeDecoder,
diff --git a/src/java/org/apache/cassandra/transport/SimpleClient.java 
b/src/java/org/apache/cassandra/transport/SimpleClient.java
index f9c9b7f5d7..a7227c89a5 100644
--- a/src/java/org/apache/cassandra/transport/SimpleClient.java
+++ b/src/java/org/apache/cassandra/transport/SimpleClient.java
@@ -504,6 +504,7 @@ public class SimpleClient implements Closeable
 
             CQLMessageHandler<Message.Response> processor =
                 new CQLMessageHandler<Message.Response>(ctx.channel(),
+                                        null,
                                         version,
                                         frameDecoder,
                                         envelopeDecoder,
diff --git 
a/test/unit/org/apache/cassandra/transport/AuthMessageSizeLimitTest.java 
b/test/unit/org/apache/cassandra/transport/AuthMessageSizeLimitTest.java
new file mode 100644
index 0000000000..5eaee511d5
--- /dev/null
+++ b/test/unit/org/apache/cassandra/transport/AuthMessageSizeLimitTest.java
@@ -0,0 +1,104 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.cassandra.transport;
+
+import org.junit.Before;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+import com.datastax.driver.core.Authenticator;
+import com.datastax.driver.core.EndPoint;
+import com.datastax.driver.core.PlainTextAuthProvider;
+import org.apache.cassandra.Util;
+import org.apache.cassandra.config.DatabaseDescriptor;
+import org.apache.cassandra.net.FrameEncoder;
+import org.apache.cassandra.transport.messages.AuthResponse;
+import org.apache.cassandra.transport.messages.QueryMessage;
+import org.assertj.core.api.Assertions;
+
+public class AuthMessageSizeLimitTest extends NativeProtocolLimitsTestBase
+{
+    private static final int TOO_BIG_MULTI_FRAME_AUTH_MESSAGE_SIZE = 2 * 
FrameEncoder.Payload.MAX_SIZE;
+
+    // set MAX_CQL_MESSAGE_SIZE bigger than 
TOO_BIG_MULTI_FRAME_AUTH_MESSAGE_SIZE to ensure that the auth message size 
check is more restrictive
+    private static final int MAX_CQL_MESSAGE_SIZE = 
TOO_BIG_MULTI_FRAME_AUTH_MESSAGE_SIZE * 2;
+
+    @BeforeClass
+    public static void setUp()
+    {
+        DatabaseDescriptor.setNativeTransportReceiveQueueCapacityInBytes(1);
+        
DatabaseDescriptor.setNativeTransportMaxRequestDataInFlightPerIpInBytes(MAX_CQL_MESSAGE_SIZE);
+        
DatabaseDescriptor.setNativeTransportConcurrentRequestDataInFlightInBytes(MAX_CQL_MESSAGE_SIZE);
+        
DatabaseDescriptor.setNativeTransportMaxMessageSizeInBytes(MAX_CQL_MESSAGE_SIZE);
+        requireNetwork();
+        requireAuthentication();
+    }
+
+    @Before
+    public void setLimits()
+    {
+        ClientResourceLimits.setGlobalLimit(MAX_CQL_MESSAGE_SIZE);
+        ClientResourceLimits.setEndpointLimit(MAX_CQL_MESSAGE_SIZE);
+    }
+
+    @Test
+    public void sendSmallAuthMessage()
+    {
+        doTest((client) ->
+               {
+                   AuthResponse authResponse = createAuthMessage("cassandra", 
"cassandra");
+                   client.execute(authResponse);
+                   createTable(client);
+
+                   int valueLessThanMessageMaxSize = MAX_CQL_MESSAGE_SIZE - 
500;
+                   QueryMessage queryMessage = 
queryMessage(valueLessThanMessageMaxSize);
+                   client.execute(queryMessage);
+               }
+        );
+    }
+
+    @Test
+    public void sendTooBigAuthMultiFrameMessage()
+    {
+        doTest((client) ->
+               {
+                   AuthResponse authResponse = createAuthMessage("cassandra", 
createIncorrectLongPassword(TOO_BIG_MULTI_FRAME_AUTH_MESSAGE_SIZE));
+                   Assertions.assertThatThrownBy(() -> 
client.execute(authResponse))
+                             .hasCauseInstanceOf(ProtocolException.class)
+                             
.hasMessageContaining(CQLMessageHandler.MULTI_FRAME_AUTH_ERROR_MESSAGE_PREFIX);
+                   Util.spinAssertEquals(false, () -> 
client.connection.channel().isOpen(), 10);
+               }
+        );
+    }
+
+    private static String createIncorrectLongPassword(int length)
+    {
+        StringBuilder password = new StringBuilder(length);
+        for (int i = 0; i < length; i++)
+            password.append('a');
+        return password.toString();
+    }
+
+    private AuthResponse createAuthMessage(String username, String password)
+    {
+        PlainTextAuthProvider authProvider = new 
PlainTextAuthProvider(username, password);
+        Authenticator authenticator = authProvider.newAuthenticator((EndPoint) 
null, null);
+        return new AuthResponse(authenticator.initialResponse());
+    }
+}
diff --git 
a/test/unit/org/apache/cassandra/transport/ClientResourceLimitsTest.java 
b/test/unit/org/apache/cassandra/transport/ClientResourceLimitsTest.java
index 8e94997239..1407ab05ac 100644
--- a/test/unit/org/apache/cassandra/transport/ClientResourceLimitsTest.java
+++ b/test/unit/org/apache/cassandra/transport/ClientResourceLimitsTest.java
@@ -18,7 +18,6 @@
 
 package org.apache.cassandra.transport;
 
-import java.io.IOException;
 import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.CyclicBarrier;
 import java.util.concurrent.TimeUnit;
@@ -32,9 +31,6 @@ import org.apache.cassandra.service.StorageService;
 import org.junit.*;
 
 import org.apache.cassandra.config.DatabaseDescriptor;
-import org.apache.cassandra.cql3.CQLTester;
-import org.apache.cassandra.cql3.QueryOptions;
-import org.apache.cassandra.cql3.QueryProcessor;
 import org.apache.cassandra.db.marshal.Int32Type;
 import org.apache.cassandra.db.marshal.UTF8Type;
 import org.apache.cassandra.db.virtual.*;
@@ -50,21 +46,11 @@ import static org.junit.Assert.assertNull;
 import static org.junit.Assert.assertTrue;
 import static org.junit.Assert.fail;
 
-public class ClientResourceLimitsTest extends CQLTester
+public class ClientResourceLimitsTest extends NativeProtocolLimitsTestBase
 {
     private static final long LOW_LIMIT = 600L;
     private static final long HIGH_LIMIT = 5000000000L;
 
-    private static final QueryOptions V5_DEFAULT_OPTIONS = 
-        QueryOptions.create(QueryOptions.DEFAULT.getConsistency(),
-                            QueryOptions.DEFAULT.getValues(),
-                            QueryOptions.DEFAULT.skipMetadata(),
-                            QueryOptions.DEFAULT.getPageSize(),
-                            QueryOptions.DEFAULT.getPagingState(),
-                            QueryOptions.DEFAULT.getSerialConsistency(),
-                            ProtocolVersion.V5,
-                            KEYSPACE);
-
     @BeforeClass
     public static void setUp()
     {
@@ -89,54 +75,6 @@ public class ClientResourceLimitsTest extends CQLTester
         ClientResourceLimits.setEndpointLimit(LOW_LIMIT);
     }
 
-    @After
-    public void dropCreatedTable()
-    {
-        try
-        {
-            QueryProcessor.executeOnceInternal("DROP TABLE " + KEYSPACE + 
".atable");
-        }
-        catch (Throwable t)
-        {
-            // ignore
-        }
-    }
-
-    @SuppressWarnings("resource")
-    private SimpleClient client(boolean throwOnOverload)
-    {
-        try
-        {
-            return SimpleClient.builder(nativeAddr.getHostAddress(), 
nativePort)
-                               .protocolVersion(ProtocolVersion.V5)
-                               .useBeta()
-                               .build()
-                               .connect(false, throwOnOverload);
-        }
-        catch (IOException e)
-        {
-            throw new RuntimeException("Error initializing client", e);
-        }
-    }
-
-    @SuppressWarnings({"resource", "SameParameterValue"})
-    private SimpleClient client(boolean throwOnOverload, int 
largeMessageThreshold)
-    {
-        try
-        {
-            return SimpleClient.builder(nativeAddr.getHostAddress(), 
nativePort)
-                               .protocolVersion(ProtocolVersion.V5)
-                               .useBeta()
-                               .largeMessageThreshold(largeMessageThreshold)
-                               .build()
-                               .connect(false, throwOnOverload);
-        }
-        catch (IOException e)
-        {
-           throw new RuntimeException("Error initializing client", e);
-        }
-    }
-
     @Test
     public void testQueryExecutionWithThrowOnOverload()
     {
@@ -153,10 +91,8 @@ public class ClientResourceLimitsTest extends CQLTester
     {
         try (SimpleClient client = client(throwOnOverload))
         {
-            QueryMessage queryMessage = new QueryMessage("CREATE TABLE atable 
(pk int PRIMARY KEY, v text)",
-                                                         V5_DEFAULT_OPTIONS);
-            client.execute(queryMessage);
-            queryMessage = new QueryMessage("SELECT * FROM atable", 
V5_DEFAULT_OPTIONS);
+            createTable(client);
+            QueryMessage queryMessage = new QueryMessage("SELECT * FROM 
atable", queryOptions());
             client.execute(queryMessage);
         }
     }
@@ -189,7 +125,7 @@ public class ClientResourceLimitsTest extends CQLTester
         {
             // The first query does not trigger backpressure/pause the 
connection:
             QueryMessage queryMessage = 
-                    new QueryMessage("CREATE TABLE atable (pk int PRIMARY KEY, 
v text)", V5_DEFAULT_OPTIONS);
+                    new QueryMessage("CREATE TABLE atable (pk int PRIMARY KEY, 
v text)", queryOptions());
             Message.Response belowThresholdResponse = 
client.execute(queryMessage);
             assertEquals(0, getPausedConnectionsGauge().getValue().intValue());
             assertNoWarningContains(belowThresholdResponse, "bytes in flight");
@@ -261,7 +197,11 @@ public class ClientResourceLimitsTest extends CQLTester
     {
         // Bump the per-endpoint limit to make sure we exhaust the global
         ClientResourceLimits.setEndpointLimit(HIGH_LIMIT);
-        testOverloadedException(() -> client(true, Ints.checkedCast(LOW_LIMIT 
/ 2)));
+        // test message = 2/3 x
+        // emulated concurrent message = 2/3 x
+        // test message + emulated concurrent message = 4/3 x > x set as a 
global limit
+        emulateInFlightConcurrentMessage(LOW_LIMIT * 2 / 3);
+        testOverloadedException(() -> client(true, Ints.checkedCast(LOW_LIMIT 
/ 2)), LOW_LIMIT * 2 / 3);
     }
 
     @Test
@@ -269,18 +209,25 @@ public class ClientResourceLimitsTest extends CQLTester
     {
         // Make sure we can only exceed the per-endpoint limit
         ClientResourceLimits.setGlobalLimit(HIGH_LIMIT);
-        testOverloadedException(() -> client(true, Ints.checkedCast(LOW_LIMIT 
/ 2)));
+        // test message = 2/3 x
+        // emulated concurrent message = 2/3 x
+        // test message + emulated concurrent message = 4/3 x > x set as an 
endpoint limit
+        emulateInFlightConcurrentMessage(LOW_LIMIT * 2 / 3);
+        testOverloadedException(() -> client(true, Ints.checkedCast(LOW_LIMIT 
/ 2)), LOW_LIMIT * 2 / 3);
     }
 
     private void testOverloadedException(Supplier<SimpleClient> clientSupplier)
+    {
+        testOverloadedException(clientSupplier, LOW_LIMIT * 2);
+    }
+
+    private void testOverloadedException(Supplier<SimpleClient> 
clientSupplier, long limit)
     {
         try (SimpleClient client = clientSupplier.get())
         {
-            QueryMessage queryMessage = new QueryMessage("CREATE TABLE atable 
(pk int PRIMARY KEY, v text)",
-                                                         V5_DEFAULT_OPTIONS);
-            client.execute(queryMessage);
+            createTable(client);
 
-            queryMessage = queryMessage();
+            QueryMessage queryMessage = queryMessage(limit);
             try
             {
                 client.execute(queryMessage);
@@ -295,11 +242,7 @@ public class ClientResourceLimitsTest extends CQLTester
 
     private QueryMessage queryMessage()
     {
-        StringBuilder query = new StringBuilder("INSERT INTO atable (pk, v) 
VALUES (1, '");
-        for (int i=0; i < LOW_LIMIT * 2; i++)
-            query.append('a');
-        query.append("')");
-        return new QueryMessage(query.toString(), V5_DEFAULT_OPTIONS);
+        return queryMessage(LOW_LIMIT * 2);
     }
 
     @Test
@@ -341,7 +284,7 @@ public class ClientResourceLimitsTest extends CQLTester
             VirtualKeyspaceRegistry.instance.register(new 
VirtualKeyspace(table, ImmutableList.of(vt1)));
 
             final QueryMessage queryMessage = new 
QueryMessage(String.format("SELECT * FROM %s.%s", table, table),
-                                                               
V5_DEFAULT_OPTIONS);
+                                                               queryOptions());
             try
             {
                 Thread tester = new Thread(() -> client.execute(queryMessage));
@@ -367,8 +310,9 @@ public class ClientResourceLimitsTest extends CQLTester
         try
         {
             QueryMessage smallMessage = new QueryMessage(String.format("CREATE 
TABLE %s.atable (pk int PRIMARY KEY, v text)", KEYSPACE),
-                                                         V5_DEFAULT_OPTIONS);
+                                                         queryOptions());
             client.execute(smallMessage);
+            createTable(client);
             try
             {
                 client.execute(queryMessage());
diff --git a/test/unit/org/apache/cassandra/transport/MessageSizeLimitTest.java 
b/test/unit/org/apache/cassandra/transport/MessageSizeLimitTest.java
new file mode 100644
index 0000000000..8b7b2999f4
--- /dev/null
+++ b/test/unit/org/apache/cassandra/transport/MessageSizeLimitTest.java
@@ -0,0 +1,124 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.cassandra.transport;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.concurrent.atomic.AtomicInteger;
+
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+import org.apache.cassandra.config.DatabaseDescriptor;
+import org.apache.cassandra.exceptions.InvalidRequestException;
+import org.apache.cassandra.net.FrameEncoder;
+import org.apache.cassandra.transport.messages.QueryMessage;
+import org.assertj.core.api.Assertions;
+
+public class MessageSizeLimitTest extends NativeProtocolLimitsTestBase
+{
+    private static final int MAX_CQL_MESSAGE_SIZE = 
FrameEncoder.Payload.MAX_SIZE * 3;
+    private static final int TOO_BIG_MESSAGE_SIZE = MAX_CQL_MESSAGE_SIZE * 2;
+    private static final int NORMAL_MESSAGE_SIZE = MAX_CQL_MESSAGE_SIZE - 500;
+
+    @BeforeClass
+    public static void setUp()
+    {
+        DatabaseDescriptor.setNativeTransportReceiveQueueCapacityInBytes(1);
+        
DatabaseDescriptor.setNativeTransportMaxRequestDataInFlightPerIpInBytes(MAX_CQL_MESSAGE_SIZE);
+        
DatabaseDescriptor.setNativeTransportConcurrentRequestDataInFlightInBytes(MAX_CQL_MESSAGE_SIZE);
+        
DatabaseDescriptor.setNativeTransportMaxMessageSizeInBytes(MAX_CQL_MESSAGE_SIZE);
+        requireNetwork();
+    }
+
+    @Before
+    public void setLimits()
+    {
+        ClientResourceLimits.setGlobalLimit(MAX_CQL_MESSAGE_SIZE);
+        ClientResourceLimits.setEndpointLimit(MAX_CQL_MESSAGE_SIZE);
+    }
+
+    @Test
+    public void sendMessageWithSizeMoreThanMaxMessageSize()
+    {
+        runClientLogic((client) ->
+               {
+                   QueryMessage tooBigQueryMessage = 
queryMessage(TOO_BIG_MESSAGE_SIZE);
+                   Assertions.assertThatThrownBy(() -> 
client.execute(tooBigQueryMessage))
+                             
.hasCauseInstanceOf(InvalidRequestException.class);
+                   // InvalidRequestException: CQL Message of size 524362 
bytes exceeds allowed maximum of 262144 bytes
+
+                   // we send one more message to check that the server 
continues to process new messages in the opened connection
+                   QueryMessage queryMessage = 
queryMessage(NORMAL_MESSAGE_SIZE);
+                   client.execute(queryMessage);
+               }
+        );
+    }
+
+    @Test(timeout = 30_000)
+    public void checkThatThereIsNoStarvationForMultiFrameMessages() throws 
InterruptedException
+    {
+        runClientLogic((client) -> {}, true); // to create table
+        AtomicInteger completedSuccessfully = new AtomicInteger(0);
+        int threadsCount = 2;
+        List<Thread> threads = new ArrayList<>();
+        for (int i = 0; i < threadsCount; i++)
+        {
+            threads.add(new Thread(() -> runClientLogic((client) -> {
+                    sendMessages(client, 100, NORMAL_MESSAGE_SIZE);
+                    completedSuccessfully.incrementAndGet();
+                }, false))
+            );
+        }
+        for (Thread thread : threads)
+            thread.start();
+
+        for (Thread thread : threads)
+            thread.join();
+
+        Assert.assertEquals("not all messages were sent successfully by all 
threads",
+                            threadsCount, completedSuccessfully.get());
+    }
+
+    private void sendMessages(SimpleClient client, int messagesCount, int 
messageSize)
+    {
+        for (int i = 0; i < messagesCount; i++)
+        {
+            QueryMessage queryMessage1 = queryMessage(messageSize);
+            client.execute(queryMessage1);
+        }
+    }
+
+    @Test
+    public void sendMessageWithSizeBelowLimit()
+    {
+        runClientLogic((client) ->
+               {
+                   QueryMessage queryMessage = 
queryMessage(NORMAL_MESSAGE_SIZE);
+                   client.execute(queryMessage);
+
+                   // run one more time, to validate that the connection is 
still alive
+                   queryMessage = queryMessage(NORMAL_MESSAGE_SIZE);
+                   client.execute(queryMessage);
+               }
+        );
+    }
+}
diff --git 
a/test/unit/org/apache/cassandra/transport/NativeProtocolLimitsTestBase.java 
b/test/unit/org/apache/cassandra/transport/NativeProtocolLimitsTestBase.java
new file mode 100644
index 0000000000..221bf97b23
--- /dev/null
+++ b/test/unit/org/apache/cassandra/transport/NativeProtocolLimitsTestBase.java
@@ -0,0 +1,165 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.cassandra.transport;
+
+import java.io.IOException;
+
+import org.junit.After;
+import org.apache.cassandra.cql3.CQLTester;
+import org.apache.cassandra.cql3.QueryOptions;
+import org.apache.cassandra.cql3.QueryProcessor;
+import org.apache.cassandra.net.FrameEncoder;
+import org.apache.cassandra.transport.messages.QueryMessage;
+import org.apache.cassandra.utils.FBUtilities;
+
+public abstract class NativeProtocolLimitsTestBase extends CQLTester
+{
+    protected final ProtocolVersion version;
+
+    protected long emulatedUsedCapacity;
+
+    public NativeProtocolLimitsTestBase()
+    {
+        this(ProtocolVersion.V5);
+    }
+
+    public NativeProtocolLimitsTestBase(ProtocolVersion version)
+    {
+        this.version = version;
+    }
+
+    @After
+    public void dropCreatedTable()
+    {
+        if (emulatedUsedCapacity > 0)
+        {
+            releaseEmulatedCapacity(emulatedUsedCapacity);
+        }
+        try
+        {
+            QueryProcessor.executeOnceInternal("DROP TABLE " + KEYSPACE + 
".atable");
+        }
+        catch (Throwable t)
+        {
+            // ignore
+        }
+    }
+
+    public QueryOptions queryOptions()
+    {
+        return QueryOptions.create(QueryOptions.DEFAULT.getConsistency(),
+                                   QueryOptions.DEFAULT.getValues(),
+                                   QueryOptions.DEFAULT.skipMetadata(),
+                                   QueryOptions.DEFAULT.getPageSize(),
+                                   QueryOptions.DEFAULT.getPagingState(),
+                                   QueryOptions.DEFAULT.getSerialConsistency(),
+                                   version,
+                                   KEYSPACE);
+    }
+
+    public SimpleClient client()
+    {
+        return client(false);
+    }
+
+    @SuppressWarnings({"resource", "SameParameterValue"})
+    public SimpleClient client(boolean throwOnOverload)
+    {
+        return client(throwOnOverload, FrameEncoder.Payload.MAX_SIZE);
+    }
+
+    @SuppressWarnings({"resource", "SameParameterValue"})
+    public SimpleClient client(boolean throwOnOverload, int 
largeMessageThreshold)
+    {
+        try
+        {
+            return SimpleClient.builder(nativeAddr.getHostAddress(), 
nativePort)
+                               .protocolVersion(version)
+                               .useBeta()
+                               .largeMessageThreshold(largeMessageThreshold)
+                               .build()
+                               .connect(false, throwOnOverload);
+        }
+        catch (IOException e)
+        {
+            throw new RuntimeException("Error initializing client", e);
+        }
+    }
+
+    public void runClientLogic(ClientLogic clientLogic)
+    {
+        runClientLogic(clientLogic, true);
+    }
+
+    public void runClientLogic(ClientLogic clientLogic, boolean createTable)
+    {
+        try (SimpleClient client = client())
+        {
+            if (createTable)
+                createTable(client);
+            clientLogic.run(client);
+        }
+    }
+
+    public void createTable(SimpleClient client)
+    {
+        QueryMessage queryMessage = new QueryMessage("CREATE TABLE IF NOT 
EXISTS " +
+                                                     KEYSPACE + ".atable (pk 
int PRIMARY KEY, v text)",
+                                                     queryOptions());
+        client.execute(queryMessage);
+    }
+
+    public void doTest(ClientLogic testLogic)
+    {
+        try (SimpleClient client = client())
+        {
+            testLogic.run(client);
+        }
+    }
+    public interface ClientLogic
+    {
+        void run(SimpleClient simpleClient);
+    }
+
+    public QueryMessage queryMessage(long valueSize)
+    {
+        StringBuilder query = new StringBuilder("INSERT INTO " + KEYSPACE + 
".atable (pk, v) VALUES (1, '");
+        for (int i = 0; i < valueSize; i++)
+            query.append('a');
+        query.append("')");
+        return new QueryMessage(query.toString(), queryOptions());
+    }
+
+    protected void emulateInFlightConcurrentMessage(long length)
+    {
+        ClientResourceLimits.Allocator allocator = 
ClientResourceLimits.getAllocatorForEndpoint(FBUtilities.getJustLocalAddress());
+        ClientResourceLimits.ResourceProvider resourceProvider = new 
ClientResourceLimits.ResourceProvider.Default(allocator);
+        resourceProvider.globalLimit().allocate(length);
+        resourceProvider.endpointLimit().allocate(length);
+        emulatedUsedCapacity += length;
+    }
+
+    protected void releaseEmulatedCapacity(long length)
+    {
+        ClientResourceLimits.Allocator allocator = 
ClientResourceLimits.getAllocatorForEndpoint(FBUtilities.getJustLocalAddress());
+        ClientResourceLimits.ResourceProvider resourceProvider = new 
ClientResourceLimits.ResourceProvider.Default(allocator);
+        resourceProvider.globalLimit().release(length);
+        resourceProvider.endpointLimit().release(length);
+    }
+}
diff --git a/test/unit/org/apache/cassandra/transport/RateLimitingTest.java 
b/test/unit/org/apache/cassandra/transport/RateLimitingTest.java
index 0b3fb34a79..1e8e16a662 100644
--- a/test/unit/org/apache/cassandra/transport/RateLimitingTest.java
+++ b/test/unit/org/apache/cassandra/transport/RateLimitingTest.java
@@ -28,7 +28,9 @@ import java.util.stream.Collectors;
 
 import com.codahale.metrics.Meter;
 import com.google.common.base.Ticker;
+
 import org.awaitility.Awaitility;
+
 import org.junit.Before;
 import org.junit.BeforeClass;
 import org.junit.Test;
@@ -36,12 +38,9 @@ import org.junit.runner.RunWith;
 import org.junit.runners.Parameterized;
 
 import org.apache.cassandra.config.DatabaseDescriptor;
-import org.apache.cassandra.cql3.CQLTester;
-import org.apache.cassandra.cql3.QueryOptions;
 import org.apache.cassandra.exceptions.OverloadedException;
 import org.apache.cassandra.metrics.CassandraMetricsRegistry;
 import org.apache.cassandra.service.StorageService;
-import org.apache.cassandra.transport.messages.QueryMessage;
 import org.apache.cassandra.utils.Throwables;
 
 import static org.junit.Assert.assertEquals;
@@ -54,7 +53,7 @@ import static 
org.apache.cassandra.transport.ProtocolVersion.V4;
 
 @SuppressWarnings("UnstableApiUsage")
 @RunWith(Parameterized.class)
-public class RateLimitingTest extends CQLTester
+public class RateLimitingTest extends NativeProtocolLimitsTestBase
 {
     public static final String BACKPRESSURE_WARNING_SNIPPET = "Request 
breached global limit";
     
@@ -63,9 +62,6 @@ public class RateLimitingTest extends CQLTester
 
     private static final long MAX_LONG_CONFIG_VALUE = Long.MAX_VALUE - 1;
 
-    @Parameterized.Parameter
-    public ProtocolVersion version;
-
     @Parameterized.Parameters(name="{0}")
     public static Collection<Object[]> versions()
     {
@@ -74,6 +70,11 @@ public class RateLimitingTest extends CQLTester
                                         .collect(Collectors.toList());
     }
 
+    public RateLimitingTest(ProtocolVersion version)
+    {
+        super(version);
+    }
+
     private AtomicLong tick;
     private Ticker ticker;
 
@@ -105,6 +106,7 @@ public class RateLimitingTest extends CQLTester
         ClientResourceLimits.setGlobalLimit(MAX_LONG_CONFIG_VALUE);
     }
 
+
     @Test
     public void shouldThrowOnOverloadSmallMessages() throws Exception
     {
@@ -147,15 +149,19 @@ public class RateLimitingTest extends CQLTester
 
     private void testBytesInFlightOverload(int payloadSize) throws Exception
     {
-        try (SimpleClient client = client().connect(false, true))
+        int emulatedConcurrentMessageSize = payloadSize * 3 / 2;
+        try (SimpleClient client = client(true, LARGE_PAYLOAD_THRESHOLD_BYTES))
         {
             
StorageService.instance.setNativeTransportRateLimitingEnabled(false);
-            QueryMessage queryMessage = new QueryMessage("CREATE TABLE IF NOT 
EXISTS " + KEYSPACE + ".atable (pk int PRIMARY KEY, v text)", queryOptions());
-            client.execute(queryMessage);
+            createTable(client);
 
             
StorageService.instance.setNativeTransportRateLimitingEnabled(true);
             
ClientResourceLimits.GLOBAL_REQUEST_LIMITER.setRate(OVERLOAD_PERMITS_PER_SECOND,
 ticker);
-            ClientResourceLimits.setGlobalLimit(1);
+            // test message = 1x
+            // emulated concurrent message = 1.5x
+            // test message + emulated concurrent message = 2.5x > 2x set as a 
global limit
+            ClientResourceLimits.setGlobalLimit(payloadSize * 2);
+            emulateInFlightConcurrentMessage(emulatedConcurrentMessageSize);
 
             try
             {
@@ -170,18 +176,17 @@ public class RateLimitingTest extends CQLTester
         finally
         {
             // Sanity check bytes in flight limiter.
-            Awaitility.await().untilAsserted(() -> assertEquals(0, 
ClientResourceLimits.getCurrentGlobalUsage()));
+            Awaitility.await().untilAsserted(() -> 
assertEquals(emulatedConcurrentMessageSize, 
ClientResourceLimits.getCurrentGlobalUsage()));
             
StorageService.instance.setNativeTransportRateLimitingEnabled(false);
         }
     }
 
     private void testOverload(int payloadSize, boolean throwOnOverload) throws 
Exception
     {
-        try (SimpleClient client = client().connect(false, throwOnOverload))
+        try (SimpleClient client = client(throwOnOverload, 
LARGE_PAYLOAD_THRESHOLD_BYTES))
         {
             
StorageService.instance.setNativeTransportRateLimitingEnabled(false);
-            QueryMessage queryMessage = new QueryMessage("CREATE TABLE IF NOT 
EXISTS " + KEYSPACE + ".atable (pk int PRIMARY KEY, v text)", queryOptions());
-            client.execute(queryMessage);
+            createTable(client);
 
             
StorageService.instance.setNativeTransportRateLimitingEnabled(true);
             
ClientResourceLimits.GLOBAL_REQUEST_LIMITER.setRate(OVERLOAD_PERMITS_PER_SECOND,
 ticker);
@@ -286,40 +291,6 @@ public class RateLimitingTest extends CQLTester
         assertEquals(dispatchedPrior + 2, 
getRequestDispatchedMeter().getCount());
     }
 
-    private QueryMessage queryMessage(int length)
-    {
-        StringBuilder query = new StringBuilder("INSERT INTO " + KEYSPACE + 
".atable (pk, v) VALUES (1, '");
-        
-        for (int i = 0; i < length; i++)
-        {
-            query.append('a');
-        }
-        
-        query.append("')");
-        return new QueryMessage(query.toString(), queryOptions());
-    }
-
-    private SimpleClient client()
-    {
-        return SimpleClient.builder(nativeAddr.getHostAddress(), nativePort)
-                           .protocolVersion(version)
-                           .useBeta()
-                           
.largeMessageThreshold(LARGE_PAYLOAD_THRESHOLD_BYTES)
-                           .build();
-    }
-
-    private QueryOptions queryOptions()
-    {
-        return QueryOptions.create(QueryOptions.DEFAULT.getConsistency(),
-                                   QueryOptions.DEFAULT.getValues(),
-                                   QueryOptions.DEFAULT.skipMetadata(),
-                                   QueryOptions.DEFAULT.getPageSize(),
-                                   QueryOptions.DEFAULT.getPagingState(),
-                                   QueryOptions.DEFAULT.getSerialConsistency(),
-                                   version,
-                                   KEYSPACE);
-    }
-
     protected static Meter getRequestDispatchedMeter()
     {
         String metricName = 
"org.apache.cassandra.metrics.Client.RequestDispatched";


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@cassandra.apache.org
For additional commands, e-mail: commits-h...@cassandra.apache.org

Reply via email to