Repository: cassandra
Updated Branches:
  refs/heads/trunk d8c451923 -> 298416a74


Incomplete handling of exceptions when decoding incoming messages

patch by jasobrown; reviewed by Dinesh Joshi for CASSANDRA-14574


Project: http://git-wip-us.apache.org/repos/asf/cassandra/repo
Commit: http://git-wip-us.apache.org/repos/asf/cassandra/commit/298416a7
Tree: http://git-wip-us.apache.org/repos/asf/cassandra/tree/298416a7
Diff: http://git-wip-us.apache.org/repos/asf/cassandra/diff/298416a7

Branch: refs/heads/trunk
Commit: 298416a7445aa50874caebc779ca3094b32f3e31
Parents: d8c4519
Author: Jason Brown <jasedbr...@gmail.com>
Authored: Wed Jul 18 13:47:22 2018 -0700
Committer: Jason Brown <jasedbr...@gmail.com>
Committed: Fri Aug 17 05:54:37 2018 -0700

----------------------------------------------------------------------
 CHANGES.txt                                     |   1 +
 .../net/async/BaseMessageInHandler.java         |  61 ++++++++-
 .../cassandra/net/async/MessageInHandler.java   | 121 +++++++---------
 .../net/async/MessageInHandlerPre40.java        | 137 +++++++++----------
 .../test/microbench/MessageOutBench.java        |   6 +-
 .../net/async/MessageInHandlerTest.java         |  65 ++++++++-
 6 files changed, 238 insertions(+), 153 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/cassandra/blob/298416a7/CHANGES.txt
----------------------------------------------------------------------
diff --git a/CHANGES.txt b/CHANGES.txt
index d2970a4..0e671b0 100644
--- a/CHANGES.txt
+++ b/CHANGES.txt
@@ -1,4 +1,5 @@
 4.0
+ * Incomplete handling of exceptions when decoding incoming messages (CASSANDRA-14574)
  * Add diagnostic events for user audit logging (CASSANDRA-13668)
  * Allow retrieving diagnostic events via JMX (CASSANDRA-14435)
  * Add base classes for diagnostic events (CASSANDRA-13457)

http://git-wip-us.apache.org/repos/asf/cassandra/blob/298416a7/src/java/org/apache/cassandra/net/async/BaseMessageInHandler.java
----------------------------------------------------------------------
diff --git a/src/java/org/apache/cassandra/net/async/BaseMessageInHandler.java 
b/src/java/org/apache/cassandra/net/async/BaseMessageInHandler.java
index 7314999..2f2a973 100644
--- a/src/java/org/apache/cassandra/net/async/BaseMessageInHandler.java
+++ b/src/java/org/apache/cassandra/net/async/BaseMessageInHandler.java
@@ -26,7 +26,6 @@ import java.util.Map;
 import java.util.function.BiConsumer;
 
 import com.google.common.annotations.VisibleForTesting;
-
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -40,6 +39,14 @@ import org.apache.cassandra.net.MessageIn;
 import org.apache.cassandra.net.MessagingService;
 import org.apache.cassandra.net.ParameterType;
 
+/**
+ * Parses out individual messages from the incoming buffers. Each message, 
both header and payload, is incrementally built up
+ * from the available input data, then passed to the {@link #messageConsumer}.
+ *
+ * Note: this class derives from {@link ByteToMessageDecoder} to take 
advantage of the {@link ByteToMessageDecoder.Cumulator}
+ * behavior across {@link #decode(ChannelHandlerContext, ByteBuf, List)} 
invocations. That way we don't have to maintain
+ * the not-fully consumed {@link ByteBuf}s.
+ */
 public abstract class BaseMessageInHandler extends ByteToMessageDecoder
 {
     public static final Logger logger = 
LoggerFactory.getLogger(BaseMessageInHandler.class);
@@ -52,7 +59,8 @@ public abstract class BaseMessageInHandler extends 
ByteToMessageDecoder
         READ_PARAMETERS_SIZE,
         READ_PARAMETERS_DATA,
         READ_PAYLOAD_SIZE,
-        READ_PAYLOAD
+        READ_PAYLOAD,
+        CLOSED
     }
 
     /**
@@ -77,6 +85,8 @@ public abstract class BaseMessageInHandler extends 
ByteToMessageDecoder
     final InetAddressAndPort peer;
     final int messagingVersion;
 
+    protected State state;
+
     public BaseMessageInHandler(InetAddressAndPort peer, int messagingVersion, 
BiConsumer<MessageIn, Integer> messageConsumer)
     {
         this.peer = peer;
@@ -84,7 +94,36 @@ public abstract class BaseMessageInHandler extends 
ByteToMessageDecoder
         this.messageConsumer = messageConsumer;
     }
 
-    public abstract void decode(ChannelHandlerContext ctx, ByteBuf in, 
List<Object> out);
+    // redeclared here to make the method public (for testing)
+    @VisibleForTesting
+    public void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> 
out) throws Exception
+    {
+        if (state == State.CLOSED)
+        {
+            in.skipBytes(in.readableBytes());
+            return;
+        }
+
+        try
+        {
+            handleDecode(ctx, in, out);
+        }
+        catch (Exception e)
+        {
+            // prevent any future attempts at reading messages from any 
inbound buffers, as we're already in a bad state
+            state = State.CLOSED;
+
+            // force the buffer to appear to be consumed, thereby exiting the 
ByteToMessageDecoder.callDecode() loop,
+            // and other paths in that class, more efficiently
+            in.skipBytes(in.readableBytes());
+
+            // throwing the exception up causes the 
ByteToMessageDecoder.callDecode() loop to exit. If we don't do that,
+            // we'll keep trying to process data out of the last received 
buffer (and it'll be really, really wrong)
+            throw e;
+        }
+    }
+
+    public abstract void handleDecode(ChannelHandlerContext ctx, ByteBuf in, 
List<Object> out) throws Exception;
 
     MessageHeader readFirstChunk(ByteBuf in) throws IOException
     {
@@ -105,11 +144,13 @@ public abstract class BaseMessageInHandler extends 
ByteToMessageDecoder
         if (cause instanceof EOFException)
             logger.trace("eof reading from socket; closing", cause);
         else if (cause instanceof UnknownTableException)
-            logger.warn("Got message from unknown table while reading from 
socket; closing", cause);
+            logger.warn(" Got message from unknown table while reading from 
socket {}[{}]; closing",
+                        ctx.channel().remoteAddress(), ctx.channel().id(), 
cause);
         else if (cause instanceof IOException)
             logger.trace("IOException reading from socket; closing", cause);
         else
-            logger.warn("Unexpected exception caught in inbound channel 
pipeline from " + ctx.channel().remoteAddress(), cause);
+            logger.warn("Unexpected exception caught in inbound channel 
pipeline from {}[{}]",
+                        ctx.channel().remoteAddress(), ctx.channel().id(), 
cause);
 
         ctx.close();
     }
@@ -118,6 +159,7 @@ public abstract class BaseMessageInHandler extends 
ByteToMessageDecoder
     public void channelInactive(ChannelHandlerContext ctx)
     {
         logger.trace("received channel closed message for peer {} on local 
addr {}", ctx.channel().remoteAddress(), ctx.channel().localAddress());
+        state = State.CLOSED;
         ctx.fireChannelInactive();
     }
 
@@ -128,7 +170,7 @@ public abstract class BaseMessageInHandler extends 
ByteToMessageDecoder
     /**
      * A simple struct to hold the message header data as it is being built up.
      */
-    protected static class MessageHeader
+    static class MessageHeader
     {
         int messageId;
         long constructionTime;
@@ -145,4 +187,11 @@ public abstract class BaseMessageInHandler extends 
ByteToMessageDecoder
          */
         int parameterLength;
     }
+
+    // for testing purposes only!!!
+    @VisibleForTesting
+    public State getState()
+    {
+        return state;
+    }
 }

http://git-wip-us.apache.org/repos/asf/cassandra/blob/298416a7/src/java/org/apache/cassandra/net/async/MessageInHandler.java
----------------------------------------------------------------------
diff --git a/src/java/org/apache/cassandra/net/async/MessageInHandler.java 
b/src/java/org/apache/cassandra/net/async/MessageInHandler.java
index eb22e91..0a194d4 100644
--- a/src/java/org/apache/cassandra/net/async/MessageInHandler.java
+++ b/src/java/org/apache/cassandra/net/async/MessageInHandler.java
@@ -32,8 +32,6 @@ import org.slf4j.LoggerFactory;
 
 import io.netty.buffer.ByteBuf;
 import io.netty.channel.ChannelHandlerContext;
-import io.netty.handler.codec.ByteToMessageDecoder;
-import org.apache.cassandra.db.monitoring.ApproximateTime;
 import org.apache.cassandra.io.util.DataInputBuffer;
 import org.apache.cassandra.locator.InetAddressAndPort;
 import org.apache.cassandra.net.MessageIn;
@@ -42,18 +40,12 @@ import org.apache.cassandra.net.ParameterType;
 import org.apache.cassandra.utils.vint.VIntCoding;
 
 /**
- * Parses out individual messages from the incoming buffers. Each message, 
both header and payload, is incrementally built up
- * from the available input data, then passed to the {@link #messageConsumer}.
- *
- * Note: this class derives from {@link ByteToMessageDecoder} to take 
advantage of the {@link ByteToMessageDecoder.Cumulator}
- * behavior across {@link #decode(ChannelHandlerContext, ByteBuf, List)} 
invocations. That way we don't have to maintain
- * the not-fully consumed {@link ByteBuf}s.
+ * Parses incoming messages as per the 4.0 internode messaging protocol.
  */
 public class MessageInHandler extends BaseMessageInHandler
 {
     public static final Logger logger = 
LoggerFactory.getLogger(MessageInHandler.class);
 
-    private State state;
     private MessageHeader messageHeader;
 
     MessageInHandler(InetAddressAndPort peer, int messagingVersion)
@@ -76,77 +68,70 @@ public class MessageInHandler extends BaseMessageInHandler
      * maintains a trivial state machine to remember progress across 
invocations.
      */
     @SuppressWarnings("resource")
-    public void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out)
+    public void handleDecode(ChannelHandlerContext ctx, ByteBuf in, 
List<Object> out) throws Exception
     {
         ByteBufDataInputPlus inputPlus = new ByteBufDataInputPlus(in);
-        try
+        while (true)
         {
-            while (true)
+            switch (state)
             {
-                switch (state)
-                {
-                    case READ_FIRST_CHUNK:
-                        MessageHeader header = readFirstChunk(in);
-                        if (header == null)
-                            return;
-                        header.from = peer;
-                        messageHeader = header;
-                        state = State.READ_VERB;
-                        // fall-through
-                    case READ_VERB:
-                        if (in.readableBytes() < VERB_LENGTH)
-                            return;
-                        messageHeader.verb = 
MessagingService.Verb.fromId(in.readInt());
-                        state = State.READ_PARAMETERS_SIZE;
-                        // fall-through
-                    case READ_PARAMETERS_SIZE:
-                        long length = VIntCoding.readUnsignedVInt(in);
-                        if (length < 0)
-                            return;
-                        messageHeader.parameterLength = (int) length;
-                        messageHeader.parameters = 
messageHeader.parameterLength == 0 ? Collections.emptyMap() : new 
EnumMap<>(ParameterType.class);
-                        state = State.READ_PARAMETERS_DATA;
-                        // fall-through
-                    case READ_PARAMETERS_DATA:
-                        if (messageHeader.parameterLength > 0)
-                        {
-                            if (in.readableBytes() < 
messageHeader.parameterLength)
-                                return;
-                            readParameters(in, inputPlus, 
messageHeader.parameterLength, messageHeader.parameters);
-                        }
-                        state = State.READ_PAYLOAD_SIZE;
-                        // fall-through
-                    case READ_PAYLOAD_SIZE:
-                        length = VIntCoding.readUnsignedVInt(in);
-                        if (length < 0)
-                            return;
-                        messageHeader.payloadSize = (int) length;
-                        state = State.READ_PAYLOAD;
-                        // fall-through
-                    case READ_PAYLOAD:
-                        if (in.readableBytes() < messageHeader.payloadSize)
+                case READ_FIRST_CHUNK:
+                    MessageHeader header = readFirstChunk(in);
+                    if (header == null)
+                        return;
+                    header.from = peer;
+                    messageHeader = header;
+                    state = State.READ_VERB;
+                    // fall-through
+                case READ_VERB:
+                    if (in.readableBytes() < VERB_LENGTH)
+                        return;
+                    messageHeader.verb = 
MessagingService.Verb.fromId(in.readInt());
+                    state = State.READ_PARAMETERS_SIZE;
+                    // fall-through
+                case READ_PARAMETERS_SIZE:
+                    long length = VIntCoding.readUnsignedVInt(in);
+                    if (length < 0)
+                        return;
+                    messageHeader.parameterLength = (int) length;
+                    messageHeader.parameters = messageHeader.parameterLength 
== 0 ? Collections.emptyMap() : new EnumMap<>(ParameterType.class);
+                    state = State.READ_PARAMETERS_DATA;
+                    // fall-through
+                case READ_PARAMETERS_DATA:
+                    if (messageHeader.parameterLength > 0)
+                    {
+                        if (in.readableBytes() < messageHeader.parameterLength)
                             return;
+                        readParameters(in, inputPlus, 
messageHeader.parameterLength, messageHeader.parameters);
+                    }
+                    state = State.READ_PAYLOAD_SIZE;
+                    // fall-through
+                case READ_PAYLOAD_SIZE:
+                    length = VIntCoding.readUnsignedVInt(in);
+                    if (length < 0)
+                        return;
+                    messageHeader.payloadSize = (int) length;
+                    state = State.READ_PAYLOAD;
+                    // fall-through
+                case READ_PAYLOAD:
+                    if (in.readableBytes() < messageHeader.payloadSize)
+                        return;
 
-                        // TODO consider deserializing the message not on the 
event loop
-                        MessageIn<Object> messageIn = 
MessageIn.read(inputPlus, messagingVersion,
+                    // TODO consider deserializing the message not on the 
event loop
+                    MessageIn<Object> messageIn = MessageIn.read(inputPlus, 
messagingVersion,
                                                                      
messageHeader.messageId, messageHeader.constructionTime, messageHeader.from,
                                                                      
messageHeader.payloadSize, messageHeader.verb, messageHeader.parameters);
 
-                        if (messageIn != null)
-                            messageConsumer.accept(messageIn, 
messageHeader.messageId);
+                    if (messageIn != null)
+                        messageConsumer.accept(messageIn, 
messageHeader.messageId);
 
-                        state = State.READ_FIRST_CHUNK;
-                        messageHeader = null;
-                        break;
-                    default:
-                        throw new IllegalStateException("unknown/unhandled 
state: " + state);
-                }
+                    state = State.READ_FIRST_CHUNK;
+                    messageHeader = null;
+                    break;
+                default:
+                    throw new IllegalStateException("unknown/unhandled state: 
" + state);
             }
         }
-        catch (Exception e)
-        {
-            exceptionCaught(ctx, e);
-        }
     }
 
     private void readParameters(ByteBuf in, ByteBufDataInputPlus inputPlus, 
int parameterLength, Map<ParameterType, Object> parameters) throws IOException

http://git-wip-us.apache.org/repos/asf/cassandra/blob/298416a7/src/java/org/apache/cassandra/net/async/MessageInHandlerPre40.java
----------------------------------------------------------------------
diff --git a/src/java/org/apache/cassandra/net/async/MessageInHandlerPre40.java 
b/src/java/org/apache/cassandra/net/async/MessageInHandlerPre40.java
index fb19b43..f5b6fc4 100644
--- a/src/java/org/apache/cassandra/net/async/MessageInHandlerPre40.java
+++ b/src/java/org/apache/cassandra/net/async/MessageInHandlerPre40.java
@@ -39,6 +39,9 @@ import org.apache.cassandra.net.MessageIn;
 import org.apache.cassandra.net.MessagingService;
 import org.apache.cassandra.net.ParameterType;
 
+/**
+ * Parses incoming messages as per the pre-4.0 internode messaging protocol.
+ */
 public class MessageInHandlerPre40 extends BaseMessageInHandler
 {
     public static final Logger logger = 
LoggerFactory.getLogger(MessageInHandlerPre40.class);
@@ -47,7 +50,6 @@ public class MessageInHandlerPre40 extends 
BaseMessageInHandler
     static final int PARAMETERS_VALUE_SIZE_LENGTH = Integer.BYTES;
     static final int PAYLOAD_SIZE_LENGTH = Integer.BYTES;
 
-    private State state;
     private MessageHeader messageHeader;
 
     MessageInHandlerPre40(InetAddressAndPort peer, int messagingVersion)
@@ -70,83 +72,76 @@ public class MessageInHandlerPre40 extends 
BaseMessageInHandler
      * maintains a trivial state machine to remember progress across 
invocations.
      */
     @SuppressWarnings("resource")
-    public void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out)
+    public void handleDecode(ChannelHandlerContext ctx, ByteBuf in, 
List<Object> out) throws Exception
     {
         ByteBufDataInputPlus inputPlus = new ByteBufDataInputPlus(in);
-        try
+        while (true)
         {
-            while (true)
+            switch (state)
             {
-                switch (state)
-                {
-                    case READ_FIRST_CHUNK:
-                        MessageHeader header = readFirstChunk(in);
-                        if (header == null)
-                            return;
-                        messageHeader = header;
-                        state = State.READ_IP_ADDRESS;
-                        // fall-through
-                    case READ_IP_ADDRESS:
-                        // unfortunately, this assumes knowledge of how 
CompactEndpointSerializationHelper serializes data (the first byte is the size).
-                        // first, check that we can actually read the size 
byte, then check if we can read that number of bytes.
-                        // the "+ 1" is to make sure we have the size byte in 
addition to the serialized IP addr count of bytes in the buffer.
-                        int readableBytes = in.readableBytes();
-                        if (readableBytes < 1 || readableBytes < 
in.getByte(in.readerIndex()) + 1)
-                            return;
-                        messageHeader.from = 
CompactEndpointSerializationHelper.instance.deserialize(inputPlus, 
messagingVersion);
-                        state = State.READ_VERB;
-                        // fall-through
-                    case READ_VERB:
-                        if (in.readableBytes() < VERB_LENGTH)
-                            return;
-                        messageHeader.verb = 
MessagingService.Verb.fromId(in.readInt());
-                        state = State.READ_PARAMETERS_SIZE;
-                        // fall-through
-                    case READ_PARAMETERS_SIZE:
-                        if (in.readableBytes() < PARAMETERS_SIZE_LENGTH)
-                            return;
-                        messageHeader.parameterLength = in.readInt();
-                        messageHeader.parameters = 
messageHeader.parameterLength == 0 ? Collections.emptyMap() : new 
EnumMap<>(ParameterType.class);
-                        state = State.READ_PARAMETERS_DATA;
-                        // fall-through
-                    case READ_PARAMETERS_DATA:
-                        if (messageHeader.parameterLength > 0)
-                        {
-                            if (!readParameters(in, inputPlus, 
messageHeader.parameterLength, messageHeader.parameters))
-                                return;
-                        }
-                        state = State.READ_PAYLOAD_SIZE;
-                        // fall-through
-                    case READ_PAYLOAD_SIZE:
-                        if (in.readableBytes() < PAYLOAD_SIZE_LENGTH)
-                            return;
-                        messageHeader.payloadSize = in.readInt();
-                        state = State.READ_PAYLOAD;
-                        // fall-through
-                    case READ_PAYLOAD:
-                        if (in.readableBytes() < messageHeader.payloadSize)
+                case READ_FIRST_CHUNK:
+                    MessageHeader header = readFirstChunk(in);
+                    if (header == null)
+                        return;
+                    messageHeader = header;
+                    state = State.READ_IP_ADDRESS;
+                    // fall-through
+                case READ_IP_ADDRESS:
+                    // unfortunately, this assumes knowledge of how 
CompactEndpointSerializationHelper serializes data (the first byte is the size).
+                    // first, check that we can actually read the size byte, 
then check if we can read that number of bytes.
+                    // the "+ 1" is to make sure we have the size byte in 
addition to the serialized IP addr count of bytes in the buffer.
+                    int readableBytes = in.readableBytes();
+                    if (readableBytes < 1 || readableBytes < 
in.getByte(in.readerIndex()) + 1)
+                        return;
+                    messageHeader.from = 
CompactEndpointSerializationHelper.instance.deserialize(inputPlus, 
messagingVersion);
+                    state = State.READ_VERB;
+                    // fall-through
+                case READ_VERB:
+                    if (in.readableBytes() < VERB_LENGTH)
+                        return;
+                    messageHeader.verb = 
MessagingService.Verb.fromId(in.readInt());
+                    state = State.READ_PARAMETERS_SIZE;
+                    // fall-through
+                case READ_PARAMETERS_SIZE:
+                    if (in.readableBytes() < PARAMETERS_SIZE_LENGTH)
+                        return;
+                    messageHeader.parameterLength = in.readInt();
+                    messageHeader.parameters = messageHeader.parameterLength 
== 0 ? Collections.emptyMap() : new EnumMap<>(ParameterType.class);
+                    state = State.READ_PARAMETERS_DATA;
+                    // fall-through
+                case READ_PARAMETERS_DATA:
+                    if (messageHeader.parameterLength > 0)
+                    {
+                        if (!readParameters(in, inputPlus, 
messageHeader.parameterLength, messageHeader.parameters))
                             return;
-
-                        // TODO consider deserailizing the messge not on the 
event loop
-                        MessageIn<Object> messageIn = 
MessageIn.read(inputPlus, messagingVersion,
-                                                                     
messageHeader.messageId, messageHeader.constructionTime, messageHeader.from,
-                                                                     
messageHeader.payloadSize, messageHeader.verb, messageHeader.parameters);
-
-                        if (messageIn != null)
-                            messageConsumer.accept(messageIn, 
messageHeader.messageId);
-
-                        state = State.READ_FIRST_CHUNK;
-                        messageHeader = null;
-                        break;
-                    default:
-                        throw new IllegalStateException("unknown/unhandled 
state: " + state);
-                }
+                    }
+                    state = State.READ_PAYLOAD_SIZE;
+                    // fall-through
+                case READ_PAYLOAD_SIZE:
+                    if (in.readableBytes() < PAYLOAD_SIZE_LENGTH)
+                        return;
+                    messageHeader.payloadSize = in.readInt();
+                    state = State.READ_PAYLOAD;
+                    // fall-through
+                case READ_PAYLOAD:
+                    if (in.readableBytes() < messageHeader.payloadSize)
+                        return;
+
+                    // TODO consider deserializing the message not on the event 
loop
+                    MessageIn<Object> messageIn = MessageIn.read(inputPlus, 
messagingVersion,
+                                                                 
messageHeader.messageId, messageHeader.constructionTime, messageHeader.from,
+                                                                 
messageHeader.payloadSize, messageHeader.verb, messageHeader.parameters);
+
+                    if (messageIn != null)
+                        messageConsumer.accept(messageIn, 
messageHeader.messageId);
+
+                    state = State.READ_FIRST_CHUNK;
+                    messageHeader = null;
+                    break;
+                default:
+                    throw new IllegalStateException("unknown/unhandled state: 
" + state);
             }
         }
-        catch (Exception e)
-        {
-            exceptionCaught(ctx, e);
-        }
     }
 
     /**

http://git-wip-us.apache.org/repos/asf/cassandra/blob/298416a7/test/microbench/org/apache/cassandra/test/microbench/MessageOutBench.java
----------------------------------------------------------------------
diff --git 
a/test/microbench/org/apache/cassandra/test/microbench/MessageOutBench.java 
b/test/microbench/org/apache/cassandra/test/microbench/MessageOutBench.java
index 43b0c16..2aec668 100644
--- a/test/microbench/org/apache/cassandra/test/microbench/MessageOutBench.java
+++ b/test/microbench/org/apache/cassandra/test/microbench/MessageOutBench.java
@@ -99,12 +99,12 @@ public class MessageOutBench
     }
 
     @Benchmark
-    public int serialize40() throws IOException
+    public int serialize40() throws Exception
     {
         return serialize(MessagingService.VERSION_40, handler40);
     }
 
-    private int serialize(int messagingVersion, BaseMessageInHandler handler) 
throws IOException
+    private int serialize(int messagingVersion, BaseMessageInHandler handler) 
throws Exception
     {
         buf.resetReaderIndex();
         buf.resetWriterIndex();
@@ -118,7 +118,7 @@ public class MessageOutBench
     }
 
     @Benchmark
-    public int serializePre40() throws IOException
+    public int serializePre40() throws Exception
     {
         return serialize(MessagingService.VERSION_30, handlerPre40);
     }

http://git-wip-us.apache.org/repos/asf/cassandra/blob/298416a7/test/unit/org/apache/cassandra/net/async/MessageInHandlerTest.java
----------------------------------------------------------------------
diff --git a/test/unit/org/apache/cassandra/net/async/MessageInHandlerTest.java 
b/test/unit/org/apache/cassandra/net/async/MessageInHandlerTest.java
index 8deb6dc..5997861 100644
--- a/test/unit/org/apache/cassandra/net/async/MessageInHandlerTest.java
+++ b/test/unit/org/apache/cassandra/net/async/MessageInHandlerTest.java
@@ -26,6 +26,7 @@ import java.util.Collections;
 import java.util.EnumMap;
 import java.util.List;
 import java.util.Map;
+import java.util.Random;
 import java.util.UUID;
 import java.util.function.BiConsumer;
 
@@ -148,7 +149,7 @@ public class MessageInHandlerTest
         MessageOut msgOut = new MessageOut<>(addr, MessagingService.Verb.ECHO, 
null, null, ImmutableList.of(), SMALL_MESSAGE);
         for (Map.Entry<ParameterType, Object> param : parameters.entrySet())
             msgOut = msgOut.withParameter(param.getKey(), param.getValue());
-        serialize(msgOut);
+        serialize(msgOut, MSG_ID);
 
         MessageInWrapper wrapper = new MessageInWrapper();
         BaseMessageInHandler handler = getHandler(addr, messagingVersion, 
wrapper.messageConsumer);
@@ -164,11 +165,12 @@ public class MessageInHandlerTest
         return wrapper;
     }
 
-    private void serialize(MessageOut msgOut) throws IOException
+    private void serialize(MessageOut msgOut, int id) throws IOException
     {
-        buf = Unpooled.buffer(1024, 1024); // 1k should be enough for 
everybody!
+        if (buf == null)
+            buf = Unpooled.buffer(1024, 1024); // 1k should be enough for 
everybody!
         buf.writeInt(MessagingService.PROTOCOL_MAGIC);
-        buf.writeInt(MSG_ID); // this is the id
+        buf.writeInt(id); // this is the id
         buf.writeInt((int) 
NanoTimeToCurrentTimeMillis.convert(System.nanoTime()));
 
         msgOut.serialize(new ByteBufDataOutputPlus(buf), messagingVersion);
@@ -181,7 +183,7 @@ public class MessageInHandlerTest
         UUID uuid = UUIDGen.getTimeUUID();
         msgOut = msgOut.withParameter(ParameterType.TRACE_SESSION, uuid);
 
-        serialize(msgOut);
+        serialize(msgOut, MSG_ID);
 
         // move the write index pointer back a few bytes to simulate like the 
full bytes are not present.
         // yeah, it's lame, but it tests the basics of what is happening 
during the deserialiization
@@ -270,6 +272,52 @@ public class MessageInHandlerTest
         Assert.assertFalse(channel.isOpen());
     }
 
+    /**
+     * this is for handling the bug uncovered by CASSANDRA-14574.
+     *
+     * TL;DR if we run into a problem processing a message out of an incoming 
buffer (and we close the channel, etc),
+     * do not attempt to process any more messages from the buffer (force the 
channel closed and
+     * reject any more read attempts from the buffer).
+     *
+     * The idea here is to put several messages into a ByteBuf, pass that to 
the channel/handler, and make sure that
+     * only the initial, correct messages in the buffer are processed. After 
one message fails, the rest of the buffer
+     * should be ignored.
+     */
+    @Test
+    public void exceptionHandled_14574() throws IOException
+    {
+        Map<ParameterType, Object> parameters = new 
EnumMap<>(ParameterType.class);
+        parameters.put(ParameterType.FAILURE_RESPONSE, 
MessagingService.ONE_BYTE);
+        parameters.put(ParameterType.FAILURE_REASON, 
Shorts.checkedCast(RequestFailureReason.READ_TOO_MANY_TOMBSTONES.code));
+        MessageOut msgOut = new MessageOut<>(addr, MessagingService.Verb.ECHO, 
null, null, ImmutableList.of(), SMALL_MESSAGE);
+        for (Map.Entry<ParameterType, Object> param : parameters.entrySet())
+            msgOut = msgOut.withParameter(param.getKey(), param.getValue());
+
+        // put one complete, correct message into the buffer
+        serialize(msgOut, 1);
+
+        // add a second message, but intentionally corrupt it by manipulating 
a byte in its range
+        int startPosition = buf.writerIndex();
+        serialize(msgOut, 2);
+        int positionToHack = startPosition + 2;
+        buf.setByte(positionToHack, buf.getByte(positionToHack) - 1);
+
+        // add one more complete, correct message into the buffer
+        serialize(msgOut, 3);
+
+        MessageIdsWrapper wrapper = new MessageIdsWrapper();
+        BaseMessageInHandler handler = getHandler(addr, messagingVersion, 
wrapper.messageConsumer);
+        EmbeddedChannel channel = new EmbeddedChannel(handler);
+        Assert.assertTrue(channel.isOpen());
+        channel.writeOneInbound(buf);
+
+        Assert.assertFalse(buf.isReadable());
+        Assert.assertEquals(BaseMessageInHandler.State.CLOSED, 
handler.getState());
+        Assert.assertFalse(channel.isOpen());
+        Assert.assertEquals(1, wrapper.ids.size());
+        Assert.assertEquals(Integer.valueOf(1), wrapper.ids.get(0));
+    }
+
     private static class MessageInWrapper
     {
         MessageIn messageIn;
@@ -281,4 +329,11 @@ public class MessageInHandlerTest
             this.id = integer;
         };
     }
+
+    private static class MessageIdsWrapper
+    {
+        private final ArrayList<Integer> ids = new ArrayList<>();
+
+        final BiConsumer<MessageIn, Integer> messageConsumer = (messageIn, 
integer) -> ids.add(integer);
+    }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@cassandra.apache.org
For additional commands, e-mail: commits-h...@cassandra.apache.org

Reply via email to