Repository: cassandra Updated Branches: refs/heads/trunk 5db822b71 -> 06209037e
Optimize internode messaging protocol patch by jasobrown; reviewed by Dinesh Joshi for CASSANDRA-14485 Project: http://git-wip-us.apache.org/repos/asf/cassandra/repo Commit: http://git-wip-us.apache.org/repos/asf/cassandra/commit/06209037 Tree: http://git-wip-us.apache.org/repos/asf/cassandra/tree/06209037 Diff: http://git-wip-us.apache.org/repos/asf/cassandra/diff/06209037 Branch: refs/heads/trunk Commit: 06209037ea56b5a2a49615a99f1542d6ea1b2947 Parents: 5db822b Author: Jason Brown <jasedbr...@gmail.com> Authored: Tue May 29 19:21:10 2018 -0700 Committer: Jason Brown <jasedbr...@gmail.com> Committed: Mon Jun 25 06:41:30 2018 -0700 ---------------------------------------------------------------------- CHANGES.txt | 1 + .../org/apache/cassandra/net/MessageOut.java | 172 ++++++++++---- .../net/async/BaseMessageInHandler.java | 148 ++++++++++++ .../net/async/InboundHandshakeHandler.java | 6 +- .../cassandra/net/async/MessageInHandler.java | 220 +++--------------- .../net/async/MessageInHandlerPre40.java | 231 +++++++++++++++++++ .../apache/cassandra/utils/vint/VIntCoding.java | 37 +++ .../test/microbench/MessageOutBench.java | 129 +++++++++++ .../net/async/HandshakeHandlersTest.java | 31 ++- .../net/async/MessageInHandlerTest.java | 105 ++++++--- .../cassandra/utils/vint/VIntCodingTest.java | 15 ++ 11 files changed, 832 insertions(+), 263 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/cassandra/blob/06209037/CHANGES.txt ---------------------------------------------------------------------- diff --git a/CHANGES.txt b/CHANGES.txt index fb14e40..e99c9ea 100644 --- a/CHANGES.txt +++ b/CHANGES.txt @@ -1,4 +1,5 @@ 4.0 + * Optimize internode messaging protocol (CASSANDRA-14485) * Internode messaging handshake sends wrong messaging version number (CASSANDRA-14540) * Add a virtual table to expose active client connections (CASSANDRA-14458) * Clean up and refactor client metrics (CASSANDRA-14524) http://git-wip-us.apache.org/repos/asf/cassandra/blob/06209037/src/java/org/apache/cassandra/net/MessageOut.java ---------------------------------------------------------------------- diff --git a/src/java/org/apache/cassandra/net/MessageOut.java b/src/java/org/apache/cassandra/net/MessageOut.java index 30968df..834435e 100644 --- a/src/java/org/apache/cassandra/net/MessageOut.java +++ b/src/java/org/apache/cassandra/net/MessageOut.java @@ -18,6 +18,7 @@ package org.apache.cassandra.net; +import java.io.IOError; import java.io.IOException; import java.util.ArrayList; import java.util.List; @@ -29,47 +30,43 @@ import com.google.common.primitives.Ints; import org.apache.cassandra.concurrent.Stage; import org.apache.cassandra.db.TypeSizes; import org.apache.cassandra.io.IVersionedSerializer; +import org.apache.cassandra.io.util.DataOutputBuffer; import org.apache.cassandra.io.util.DataOutputPlus; import org.apache.cassandra.locator.InetAddressAndPort; import org.apache.cassandra.net.async.OutboundConnectionIdentifier.ConnectionType; import org.apache.cassandra.tracing.Tracing; import org.apache.cassandra.utils.FBUtilities; -import org.apache.cassandra.utils.Pair; +import org.apache.cassandra.utils.vint.VIntCoding; import static org.apache.cassandra.tracing.Tracing.isTracing; /** * Each message contains a header with several fixed fields, an optional key-value parameters section, and then - * the message payload itself. Note: the IP address in the header may be either IPv4 (4 bytes) or IPv6 (16 bytes). - * The diagram below shows the IPv4 address for brevity. + * the message payload itself. Note: the legacy IP address (pre-4.0) in the header may be either IPv4 (4 bytes) + * or IPv6 (16 bytes). The diagram below shows the IPv4 address for brevity. In pre-4.0, the payloadSize was + * encoded as a 4-byte integer; in 4.0 and up it is an unsigned byte (255 parameters should be enough for anyone). * * <pre> * {@code - * 1 1 1 1 1 2 2 2 2 2 3 3 3 3 3 4 4 4 4 4 5 5 5 5 5 6 6 - * 0 2 4 6 8 0 2 4 6 8 0 2 4 6 8 0 2 4 6 8 0 2 4 6 8 0 2 4 6 8 0 2 - * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - * | PROTOCOL MAGIC | - * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - * | Message ID | - * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - * | Timestamp | - * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - * | Addr len | IP Address (IPv4) / - * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - * / | Verb / - * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - * / | Parameters size / - * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - * / | Parameter data / - * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - * / | - * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - * | Payload size | - * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - * | / - * / Payload / - * / | - * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * 1 1 1 1 1 2 2 2 2 2 3 + * 0 2 4 6 8 0 2 4 6 8 0 2 4 6 8 0 + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | PROTOCOL MAGIC | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | Message ID | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | Timestamp | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | Verb | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * |ParmLen| Parameter data (var) | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | Payload size (vint) | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | / + * / Payload / + * / | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ * } * </pre> * @@ -163,7 +160,7 @@ public class MessageOut<T> newParameters.addAll(parameters); newParameters.add(type); newParameters.add(value); - return new MessageOut<T>(verb, payload, serializer, newParameters, connectionType); + return new MessageOut<T>(from, verb, payload, serializer, newParameters, connectionType); } public Stage getStage() @@ -185,25 +182,60 @@ public class MessageOut<T> public void serialize(DataOutputPlus out, int version) throws IOException { - CompactEndpointSerializationHelper.instance.serialize(from, out, version); + if (version >= MessagingService.VERSION_40) + serialize40(out, version); + else + serializePre40(out, version); + } + private void serialize40(DataOutputPlus out, int version) throws IOException + { out.writeInt(verb.getId()); + + // serialize the headers, if any assert parameters.size() % PARAMETER_TUPLE_SIZE == 0; - out.writeInt(parameters.size() / PARAMETER_TUPLE_SIZE); - for (int ii = 0; ii < parameters.size(); ii += PARAMETER_TUPLE_SIZE) + if (parameters.isEmpty()) { - ParameterType type = (ParameterType)parameters.get(ii + PARAMETER_TUPLE_TYPE_OFFSET); - out.writeUTF(type.key); - IVersionedSerializer serializer = type.serializer; - Object parameter = parameters.get(ii + PARAMETER_TUPLE_PARAMETER_OFFSET); - out.writeInt(Ints.checkedCast(serializer.serializedSize(parameter, version))); - serializer.serialize(parameter, out, version); + out.writeVInt(0); } + else + { + try (DataOutputBuffer buf = new DataOutputBuffer()) + { + serializeParams(buf, version); + out.writeUnsignedVInt(buf.getLength()); + out.write(buf.buffer()); + } + } + + if (payload != null) + { + int payloadSize = payloadSerializedSize >= 0 + ? payloadSerializedSize + : (int) serializer.serializedSize(payload, version); + + out.writeUnsignedVInt(payloadSize); + serializer.serialize(payload, out, version); + } + else + { + out.writeUnsignedVInt(0); + } + } + + private void serializePre40(DataOutputPlus out, int version) throws IOException + { + CompactEndpointSerializationHelper.instance.serialize(from, out, version); + out.writeInt(verb.getId()); + + assert parameters.size() % PARAMETER_TUPLE_SIZE == 0; + out.writeInt(parameters.size() / PARAMETER_TUPLE_SIZE); + serializeParams(out, version); if (payload != null) { int payloadSize = payloadSerializedSize >= 0 - ? (int)payloadSerializedSize + ? payloadSerializedSize : (int) serializer.serializedSize(payload, version); out.writeInt(payloadSize); @@ -215,13 +247,73 @@ public class MessageOut<T> } } + private void serializeParams(DataOutputPlus out, int version) throws IOException + { + for (int ii = 0; ii < parameters.size(); ii += PARAMETER_TUPLE_SIZE) + { + ParameterType type = (ParameterType)parameters.get(ii + PARAMETER_TUPLE_TYPE_OFFSET); + out.writeUTF(type.key); + IVersionedSerializer serializer = type.serializer; + Object parameter = parameters.get(ii + PARAMETER_TUPLE_PARAMETER_OFFSET); + + int valueLength = Ints.checkedCast(serializer.serializedSize(parameter, version)); + if (version >= MessagingService.VERSION_40) + out.writeUnsignedVInt(valueLength); + else + out.writeInt(valueLength); + + serializer.serialize(parameter, out, version); + } + } + private MessageOutSizes calculateSerializedSize(int version) { + return version >= MessagingService.VERSION_40 + ? calculateSerializedSize40(version) + : calculateSerializedSizePre40(version); + } + + private MessageOutSizes calculateSerializedSize40(int version) + { + long size = 0; + size += TypeSizes.sizeof(verb.getId()); + + if (parameters.isEmpty()) + { + size += VIntCoding.computeVIntSize(0); + } + else + { + // calculate the params size independently, as we write that before the actual params block + int paramsSize = 0; + for (int ii = 0; ii < parameters.size(); ii += PARAMETER_TUPLE_SIZE) + { + ParameterType type = (ParameterType)parameters.get(ii + PARAMETER_TUPLE_TYPE_OFFSET); + paramsSize += TypeSizes.sizeof(type.key()); + IVersionedSerializer serializer = type.serializer; + Object parameter = parameters.get(ii + PARAMETER_TUPLE_PARAMETER_OFFSET); + int valueLength = Ints.checkedCast(serializer.serializedSize(parameter, version)); + paramsSize += VIntCoding.computeUnsignedVIntSize(valueLength);//length prefix + paramsSize += valueLength; + } + size += VIntCoding.computeUnsignedVIntSize(paramsSize); + size += paramsSize; + } + + long payloadSize = payload == null ? 0 : serializer.serializedSize(payload, version); + assert payloadSize <= Integer.MAX_VALUE; // larger values are supported in sstables but not messages + size += VIntCoding.computeUnsignedVIntSize(payloadSize); + size += payloadSize; + return new MessageOutSizes(size, payloadSize); + } + + private MessageOutSizes calculateSerializedSizePre40(int version) + { long size = 0; size += CompactEndpointSerializationHelper.instance.serializedSize(from, version); size += TypeSizes.sizeof(verb.getId()); - size += TypeSizes.sizeof(parameters.size()); + size += TypeSizes.sizeof(parameters.size() / PARAMETER_TUPLE_SIZE); for (int ii = 0; ii < parameters.size(); ii += PARAMETER_TUPLE_SIZE) { ParameterType type = (ParameterType)parameters.get(ii + PARAMETER_TUPLE_TYPE_OFFSET); http://git-wip-us.apache.org/repos/asf/cassandra/blob/06209037/src/java/org/apache/cassandra/net/async/BaseMessageInHandler.java ---------------------------------------------------------------------- diff --git a/src/java/org/apache/cassandra/net/async/BaseMessageInHandler.java b/src/java/org/apache/cassandra/net/async/BaseMessageInHandler.java new file mode 100644 index 0000000..7314999 --- /dev/null +++ b/src/java/org/apache/cassandra/net/async/BaseMessageInHandler.java @@ -0,0 +1,148 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.net.async; + +import java.io.EOFException; +import java.io.IOException; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.function.BiConsumer; + +import com.google.common.annotations.VisibleForTesting; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelHandlerContext; +import io.netty.handler.codec.ByteToMessageDecoder; +import org.apache.cassandra.db.monitoring.ApproximateTime; +import org.apache.cassandra.exceptions.UnknownTableException; +import org.apache.cassandra.locator.InetAddressAndPort; +import org.apache.cassandra.net.MessageIn; +import org.apache.cassandra.net.MessagingService; +import org.apache.cassandra.net.ParameterType; + +public abstract class BaseMessageInHandler extends ByteToMessageDecoder +{ + public static final Logger logger = LoggerFactory.getLogger(BaseMessageInHandler.class); + + enum State + { + READ_FIRST_CHUNK, + READ_IP_ADDRESS, + READ_VERB, + READ_PARAMETERS_SIZE, + READ_PARAMETERS_DATA, + READ_PAYLOAD_SIZE, + READ_PAYLOAD + } + + /** + * The byte count for magic, msg id, timestamp values. + */ + @VisibleForTesting + static final int FIRST_SECTION_BYTE_COUNT = 12; + + static final int VERB_LENGTH = Integer.BYTES; + + /** + * The default target for consuming deserialized {@link MessageIn}. + */ + static final BiConsumer<MessageIn, Integer> MESSAGING_SERVICE_CONSUMER = (messageIn, id) -> MessagingService.instance().receive(messageIn, id); + + /** + * Abstracts out depending directly on {@link MessagingService#receive(MessageIn, int)}; this makes tests more sane + * as they don't require nor trigger the entire message processing circus. + */ + final BiConsumer<MessageIn, Integer> messageConsumer; + + final InetAddressAndPort peer; + final int messagingVersion; + + public BaseMessageInHandler(InetAddressAndPort peer, int messagingVersion, BiConsumer<MessageIn, Integer> messageConsumer) + { + this.peer = peer; + this.messagingVersion = messagingVersion; + this.messageConsumer = messageConsumer; + } + + public abstract void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out); + + MessageHeader readFirstChunk(ByteBuf in) throws IOException + { + if (in.readableBytes() < FIRST_SECTION_BYTE_COUNT) + return null; + MessagingService.validateMagic(in.readInt()); + MessageHeader messageHeader = new MessageInHandler.MessageHeader(); + messageHeader.messageId = in.readInt(); + int messageTimestamp = in.readInt(); // make sure to read the sent timestamp, even if DatabaseDescriptor.hasCrossNodeTimeout() is not enabled + messageHeader.constructionTime = MessageIn.deriveConstructionTime(peer, messageTimestamp, ApproximateTime.currentTimeMillis()); + + return messageHeader; + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) + { + if (cause instanceof EOFException) + logger.trace("eof reading from socket; closing", cause); + else if (cause instanceof UnknownTableException) + logger.warn("Got message from unknown table while reading from socket; closing", cause); + else if (cause instanceof IOException) + logger.trace("IOException reading from socket; closing", cause); + else + logger.warn("Unexpected exception caught in inbound channel pipeline from " + ctx.channel().remoteAddress(), cause); + + ctx.close(); + } + + @Override + public void channelInactive(ChannelHandlerContext ctx) + { + logger.trace("received channel closed message for peer {} on local addr {}", ctx.channel().remoteAddress(), ctx.channel().localAddress()); + ctx.fireChannelInactive(); + } + + // should ony be used for testing!!! + @VisibleForTesting + abstract MessageHeader getMessageHeader(); + + /** + * A simple struct to hold the message header data as it is being built up. + */ + protected static class MessageHeader + { + int messageId; + long constructionTime; + InetAddressAndPort from; + MessagingService.Verb verb; + int payloadSize; + + Map<ParameterType, Object> parameters = Collections.emptyMap(); + + /** + * Length of the parameter data. If the message's version is {@link MessagingService#VERSION_40} or higher, + * this value is the total number of header bytes; else, for legacy messaging, this is the number of + * key/value entries in the header. + */ + int parameterLength; + } +} http://git-wip-us.apache.org/repos/asf/cassandra/blob/06209037/src/java/org/apache/cassandra/net/async/InboundHandshakeHandler.java ---------------------------------------------------------------------- diff --git a/src/java/org/apache/cassandra/net/async/InboundHandshakeHandler.java b/src/java/org/apache/cassandra/net/async/InboundHandshakeHandler.java index 656680f..e66a589 100644 --- a/src/java/org/apache/cassandra/net/async/InboundHandshakeHandler.java +++ b/src/java/org/apache/cassandra/net/async/InboundHandshakeHandler.java @@ -260,7 +260,11 @@ class InboundHandshakeHandler extends ByteToMessageDecoder if (compressed) pipeline.addLast(NettyFactory.INBOUND_COMPRESSOR_HANDLER_NAME, NettyFactory.createLz4Decoder(messagingVersion)); - pipeline.addLast("messageInHandler", new MessageInHandler(peer, messagingVersion)); + BaseMessageInHandler messageInHandler = messagingVersion >= MessagingService.VERSION_40 + ? new MessageInHandler(peer, messagingVersion) + : new MessageInHandlerPre40(peer, messagingVersion); + + pipeline.addLast("messageInHandler", messageInHandler); pipeline.remove(this); } http://git-wip-us.apache.org/repos/asf/cassandra/blob/06209037/src/java/org/apache/cassandra/net/async/MessageInHandler.java ---------------------------------------------------------------------- diff --git a/src/java/org/apache/cassandra/net/async/MessageInHandler.java b/src/java/org/apache/cassandra/net/async/MessageInHandler.java index b9cbd1a..c85d860 100644 --- a/src/java/org/apache/cassandra/net/async/MessageInHandler.java +++ b/src/java/org/apache/cassandra/net/async/MessageInHandler.java @@ -19,15 +19,14 @@ package org.apache.cassandra.net.async; import java.io.DataInputStream; -import java.io.EOFException; import java.io.IOException; import java.util.Collections; -import java.util.HashMap; +import java.util.EnumMap; import java.util.List; import java.util.Map; import java.util.function.BiConsumer; -import com.google.common.annotations.VisibleForTesting; +import com.google.common.primitives.Ints; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -35,13 +34,12 @@ import io.netty.buffer.ByteBuf; import io.netty.channel.ChannelHandlerContext; import io.netty.handler.codec.ByteToMessageDecoder; import org.apache.cassandra.db.monitoring.ApproximateTime; -import org.apache.cassandra.exceptions.UnknownTableException; import org.apache.cassandra.io.util.DataInputBuffer; import org.apache.cassandra.locator.InetAddressAndPort; -import org.apache.cassandra.net.CompactEndpointSerializationHelper; import org.apache.cassandra.net.MessageIn; import org.apache.cassandra.net.MessagingService; import org.apache.cassandra.net.ParameterType; +import org.apache.cassandra.utils.vint.VIntCoding; /** * Parses out individual messages from the incoming buffers. Each message, both header and payload, is incrementally built up @@ -51,45 +49,10 @@ import org.apache.cassandra.net.ParameterType; * behavior across {@link #decode(ChannelHandlerContext, ByteBuf, List)} invocations. That way we don't have to maintain * the not-fully consumed {@link ByteBuf}s. */ -class MessageInHandler extends ByteToMessageDecoder +public class MessageInHandler extends BaseMessageInHandler { public static final Logger logger = LoggerFactory.getLogger(MessageInHandler.class); - /** - * The default target for consuming deserialized {@link MessageIn}. - */ - static final BiConsumer<MessageIn, Integer> MESSAGING_SERVICE_CONSUMER = (messageIn, id) -> MessagingService.instance().receive(messageIn, id); - - private enum State - { - READ_FIRST_CHUNK, - READ_IP_ADDRESS, - READ_SECOND_CHUNK, - READ_PARAMETERS_DATA, - READ_PAYLOAD_SIZE, - READ_PAYLOAD - } - - /** - * The byte count for magic, msg id, timestamp values. - */ - @VisibleForTesting - static final int FIRST_SECTION_BYTE_COUNT = 12; - - /** - * The byte count for the verb id and the number of parameters. - */ - private static final int SECOND_SECTION_BYTE_COUNT = 8; - - private final InetAddressAndPort peer; - private final int messagingVersion; - - /** - * Abstracts out depending directly on {@link MessagingService#receive(MessageIn, int)}; this makes tests more sane - * as they don't require nor trigger the entire message processing circus. - */ - private final BiConsumer<MessageIn, Integer> messageConsumer; - private State state; private MessageHeader messageHeader; @@ -98,11 +61,13 @@ class MessageInHandler extends ByteToMessageDecoder this (peer, messagingVersion, MESSAGING_SERVICE_CONSUMER); } - MessageInHandler(InetAddressAndPort peer, int messagingVersion, BiConsumer<MessageIn, Integer> messageConsumer) + public MessageInHandler(InetAddressAndPort peer, int messagingVersion, BiConsumer<MessageIn, Integer> messageConsumer) { - this.peer = peer; - this.messagingVersion = messagingVersion; - this.messageConsumer = messageConsumer; + super(peer, messagingVersion, messageConsumer); + + if (messagingVersion < MessagingService.VERSION_40) + throw new IllegalArgumentException(String.format("wrong messaging version for this handler", messagingVersion)); + state = State.READ_FIRST_CHUNK; } @@ -119,64 +84,51 @@ class MessageInHandler extends ByteToMessageDecoder { while (true) { - // an imperfect optimization around calling in.readableBytes() all the time - int readableBytes = in.readableBytes(); - switch (state) { case READ_FIRST_CHUNK: - if (readableBytes < FIRST_SECTION_BYTE_COUNT) + MessageHeader header = readFirstChunk(in); + if (header == null) return; - MessagingService.validateMagic(in.readInt()); - messageHeader = new MessageHeader(); - messageHeader.messageId = in.readInt(); - int messageTimestamp = in.readInt(); // make sure to read the sent timestamp, even if DatabaseDescriptor.hasCrossNodeTimeout() is not enabled - messageHeader.constructionTime = MessageIn.deriveConstructionTime(peer, messageTimestamp, ApproximateTime.currentTimeMillis()); - state = State.READ_IP_ADDRESS; - readableBytes -= FIRST_SECTION_BYTE_COUNT; + header.from = peer; + messageHeader = header; + state = State.READ_VERB; // fall-through - case READ_IP_ADDRESS: - // unfortunately, this assumes knowledge of how CompactEndpointSerializationHelper serializes data (the first byte is the size). - // first, check that we can actually read the size byte, then check if we can read that number of bytes. - // the "+ 1" is to make sure we have the size byte in addition to the serialized IP addr count of bytes in the buffer. - int serializedAddrSize; - if (readableBytes < 1 || readableBytes < (serializedAddrSize = in.getByte(in.readerIndex()) + 1)) + case READ_VERB: + if (in.readableBytes() < VERB_LENGTH) return; - messageHeader.from = CompactEndpointSerializationHelper.instance.deserialize(inputPlus, messagingVersion); - state = State.READ_SECOND_CHUNK; - readableBytes -= serializedAddrSize; + messageHeader.verb = MessagingService.Verb.fromId(in.readInt()); + state = State.READ_PARAMETERS_SIZE; // fall-through - case READ_SECOND_CHUNK: - if (readableBytes < SECOND_SECTION_BYTE_COUNT) + case READ_PARAMETERS_SIZE: + long length = VIntCoding.readUnsignedVInt(in); + if (length < 0) return; - messageHeader.verb = MessagingService.Verb.fromId(in.readInt()); - int paramCount = in.readInt(); - messageHeader.parameterCount = paramCount; - messageHeader.parameters = paramCount == 0 ? Collections.emptyMap() : new HashMap<>(); + messageHeader.parameterLength = (int) length; + messageHeader.parameters = messageHeader.parameterLength == 0 ? Collections.emptyMap() : new EnumMap<>(ParameterType.class); state = State.READ_PARAMETERS_DATA; - readableBytes -= SECOND_SECTION_BYTE_COUNT; // fall-through case READ_PARAMETERS_DATA: - if (messageHeader.parameterCount > 0) + if (messageHeader.parameterLength > 0) { - if (!readParameters(in, inputPlus, messageHeader.parameterCount, messageHeader.parameters)) + if (in.readableBytes() < messageHeader.parameterLength) return; - readableBytes = in.readableBytes(); // we read an indeterminate number of bytes for the headers, so just ask the buffer again + readParameters(in, inputPlus, messageHeader.parameterLength, messageHeader.parameters); } state = State.READ_PAYLOAD_SIZE; // fall-through case READ_PAYLOAD_SIZE: - if (readableBytes < 4) + length = VIntCoding.readUnsignedVInt(in); + if (length < 0) return; - messageHeader.payloadSize = in.readInt(); + messageHeader.payloadSize = (int) length; state = State.READ_PAYLOAD; - readableBytes -= 4; // fall-through case READ_PAYLOAD: - if (readableBytes < messageHeader.payloadSize) + if (in.readableBytes() < messageHeader.payloadSize) return; - // TODO consider deserailizing the messge not on the event loop + // TODO consider deserializing the message not on the event loop MessageIn<Object> messageIn = MessageIn.read(inputPlus, messagingVersion, messageHeader.messageId, messageHeader.constructionTime, messageHeader.from, messageHeader.payloadSize, messageHeader.verb, messageHeader.parameters); @@ -198,123 +150,27 @@ class MessageInHandler extends ByteToMessageDecoder } } - /** - * @return <code>true</code> if all the parameters have been read from the {@link ByteBuf}; else, <code>false</code>. - */ - private boolean readParameters(ByteBuf in, ByteBufDataInputPlus inputPlus, int parameterCount, Map<ParameterType, Object> parameters) throws IOException + private void readParameters(ByteBuf in, ByteBufDataInputPlus inputPlus, int parameterLength, Map<ParameterType, Object> parameters) throws IOException { - // makes the assumption that map.size() is a constant time function (HashMap.size() is) - while (parameters.size() < parameterCount) + // makes the assumption we have all the bytes required to read the headers + final int endIndex = in.readerIndex() + parameterLength; + while (in.readerIndex() < endIndex) { - if (!canReadNextParam(in)) - return false; - String key = DataInputStream.readUTF(inputPlus); ParameterType parameterType = ParameterType.byName.get(key); - byte[] value = new byte[in.readInt()]; + long valueLength = VIntCoding.readUnsignedVInt(in); + byte[] value = new byte[Ints.checkedCast(valueLength)]; in.readBytes(value); try (DataInputBuffer buffer = new DataInputBuffer(value)) { parameters.put(parameterType, parameterType.serializer.deserialize(buffer, messagingVersion)); } } - - return true; - } - - /** - * Determine if we can read the next parameter from the {@link ByteBuf}. This method will *always* set the {@code in} - * readIndex back to where it was when this method was invoked. - * - * NOTE: this function would be sooo much simpler if we included a parameters length int in the messaging format, - * instead of checking the remaining readable bytes for each field as we're parsing it. c'est la vie ... - */ - @VisibleForTesting - static boolean canReadNextParam(ByteBuf in) - { - in.markReaderIndex(); - // capture the readableBytes value here to avoid all the virtual function calls. - // subtract 6 as we know we'll be reading a short and an int (for the utf and value lengths). - final int minimumBytesRequired = 6; - int readableBytes = in.readableBytes() - minimumBytesRequired; - if (readableBytes < 0) - return false; - - // this is a tad invasive, but since we know the UTF string is prefaced with a 2-byte length, - // read that to make sure we have enough bytes to read the string itself. - short strLen = in.readShort(); - // check if we can read that many bytes for the UTF - if (strLen > readableBytes) - { - in.resetReaderIndex(); - return false; - } - in.skipBytes(strLen); - readableBytes -= strLen; - - // check if we can read the value length - if (readableBytes < 4) - { - in.resetReaderIndex(); - return false; - } - int valueLength = in.readInt(); - // check if we read that many bytes for the value - if (valueLength > readableBytes) - { - in.resetReaderIndex(); - return false; - } - - in.resetReaderIndex(); - return true; } @Override - public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) - { - if (cause instanceof EOFException) - logger.trace("eof reading from socket; closing", cause); - else if (cause instanceof UnknownTableException) - logger.warn("Got message from unknown table while reading from socket; closing", cause); - else if (cause instanceof IOException) - logger.trace("IOException reading from socket; closing", cause); - else - logger.warn("Unexpected exception caught in inbound channel pipeline from " + ctx.channel().remoteAddress(), cause); - - ctx.close(); - } - - @Override - public void channelInactive(ChannelHandlerContext ctx) throws Exception - { - logger.debug("received channel closed message for peer {} on local addr {}", ctx.channel().remoteAddress(), ctx.channel().localAddress()); - ctx.fireChannelInactive(); - } - - // should ony be used for testing!!! - @VisibleForTesting MessageHeader getMessageHeader() { return messageHeader; } - - /** - * A simple struct to hold the message header data as it is being built up. - */ - static class MessageHeader - { - int messageId; - long constructionTime; - InetAddressAndPort from; - MessagingService.Verb verb; - int payloadSize; - - Map<ParameterType, Object> parameters = Collections.emptyMap(); - - /** - * Total number of incoming parameters. - */ - int parameterCount; - } } http://git-wip-us.apache.org/repos/asf/cassandra/blob/06209037/src/java/org/apache/cassandra/net/async/MessageInHandlerPre40.java ---------------------------------------------------------------------- diff --git a/src/java/org/apache/cassandra/net/async/MessageInHandlerPre40.java b/src/java/org/apache/cassandra/net/async/MessageInHandlerPre40.java new file mode 100644 index 0000000..132ec11 --- /dev/null +++ b/src/java/org/apache/cassandra/net/async/MessageInHandlerPre40.java @@ -0,0 +1,231 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.net.async; + +import java.io.DataInputStream; +import java.io.IOException; +import java.util.Collections; +import java.util.EnumMap; +import java.util.List; +import java.util.Map; +import java.util.function.BiConsumer; + +import com.google.common.annotations.VisibleForTesting; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelHandlerContext; +import org.apache.cassandra.io.util.DataInputBuffer; +import org.apache.cassandra.locator.InetAddressAndPort; +import org.apache.cassandra.net.CompactEndpointSerializationHelper; +import org.apache.cassandra.net.MessageIn; +import org.apache.cassandra.net.MessagingService; +import org.apache.cassandra.net.ParameterType; + +public class MessageInHandlerPre40 extends BaseMessageInHandler +{ + public static final Logger logger = LoggerFactory.getLogger(MessageInHandlerPre40.class); + + static final int PARAMETERS_SIZE_LENGTH = Integer.BYTES; + static final int PARAMETERS_VALUE_SIZE_LENGTH = Integer.BYTES; + static final int PAYLOAD_SIZE_LENGTH = Integer.BYTES; + + private State state; + private MessageHeader messageHeader; + + MessageInHandlerPre40(InetAddressAndPort peer, int messagingVersion) + { + this (peer, messagingVersion, MESSAGING_SERVICE_CONSUMER); + } + + public MessageInHandlerPre40(InetAddressAndPort peer, int messagingVersion, BiConsumer<MessageIn, Integer> messageConsumer) + { + super(peer, messagingVersion, messageConsumer); + + if (messagingVersion >= MessagingService.VERSION_40) + throw new IllegalArgumentException(String.format("wrong messaging version for this handler", messagingVersion)); + + state = State.READ_FIRST_CHUNK; + } + + /** + * For each new message coming in, builds up a {@link MessageHeader} instance incrementally. This method + * attempts to deserialize as much header information as it can out of the incoming {@link ByteBuf}, and + * maintains a trivial state machine to remember progress across invocations. + */ + @SuppressWarnings("resource") + public void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) + { + ByteBufDataInputPlus inputPlus = new ByteBufDataInputPlus(in); + try + { + while (true) + { + switch (state) + { + case READ_FIRST_CHUNK: + MessageHeader header = readFirstChunk(in); + if (header == null) + return; + messageHeader = header; + state = State.READ_IP_ADDRESS; + // fall-through + case READ_IP_ADDRESS: + // unfortunately, this assumes knowledge of how CompactEndpointSerializationHelper serializes data (the first byte is the size). + // first, check that we can actually read the size byte, then check if we can read that number of bytes. + // the "+ 1" is to make sure we have the size byte in addition to the serialized IP addr count of bytes in the buffer. + int readableBytes = in.readableBytes(); + if (readableBytes < 1 || readableBytes < in.getByte(in.readerIndex()) + 1) + return; + messageHeader.from = CompactEndpointSerializationHelper.instance.deserialize(inputPlus, messagingVersion); + state = State.READ_VERB; + // fall-through + case READ_VERB: + if (in.readableBytes() < VERB_LENGTH) + return; + messageHeader.verb = MessagingService.Verb.fromId(in.readInt()); + state = State.READ_PARAMETERS_SIZE; + // fall-through + case READ_PARAMETERS_SIZE: + if (in.readableBytes() < PARAMETERS_SIZE_LENGTH) + return; + messageHeader.parameterLength = in.readInt(); + messageHeader.parameters = messageHeader.parameterLength == 0 ? Collections.emptyMap() : new EnumMap<>(ParameterType.class); + state = State.READ_PARAMETERS_DATA; + // fall-through + case READ_PARAMETERS_DATA: + if (messageHeader.parameterLength > 0) + { + if (!readParameters(in, inputPlus, messageHeader.parameterLength, messageHeader.parameters)) + return; + } + state = State.READ_PAYLOAD_SIZE; + // fall-through + case READ_PAYLOAD_SIZE: + if (in.readableBytes() < PAYLOAD_SIZE_LENGTH) + return; + messageHeader.payloadSize = in.readInt(); + state = State.READ_PAYLOAD; + // fall-through + case READ_PAYLOAD: + if (in.readableBytes() < messageHeader.payloadSize) + return; + + // TODO consider deserailizing the messge not on the event loop + MessageIn<Object> messageIn = MessageIn.read(inputPlus, messagingVersion, + messageHeader.messageId, messageHeader.constructionTime, messageHeader.from, + messageHeader.payloadSize, messageHeader.verb, messageHeader.parameters); + + if (messageIn != null) + messageConsumer.accept(messageIn, messageHeader.messageId); + + state = State.READ_FIRST_CHUNK; + messageHeader = null; + break; + default: + throw new IllegalStateException("unknown/unhandled state: " + state); + } + } + } + catch (Exception e) + { + exceptionCaught(ctx, e); + } + } + + /** + * @return <code>true</code> if all the parameters have been read from the {@link ByteBuf}; else, <code>false</code>. + */ + private boolean readParameters(ByteBuf in, ByteBufDataInputPlus inputPlus, int parameterCount, Map<ParameterType, Object> parameters) throws IOException + { + // makes the assumption that map.size() is a constant time function (HashMap.size() is) + while (parameters.size() < parameterCount) + { + if (!canReadNextParam(in)) + return false; + + String key = DataInputStream.readUTF(inputPlus); + ParameterType parameterType = ParameterType.byName.get(key); + byte[] value = new byte[in.readInt()]; + in.readBytes(value); + try (DataInputBuffer buffer = new DataInputBuffer(value)) + { + parameters.put(parameterType, parameterType.serializer.deserialize(buffer, messagingVersion)); + } + } + + return true; + } + + /** + * Determine if we can read the next parameter from the {@link ByteBuf}. This method will *always* set the {@code in} + * readIndex back to where it was when this method was invoked. + * + * NOTE: this function would be sooo much simpler if we included a parameters length int in the messaging format, + * instead of checking the remaining readable bytes for each field as we're parsing it. c'est la vie ... + */ + @VisibleForTesting + static boolean canReadNextParam(ByteBuf in) + { + in.markReaderIndex(); + // capture the readableBytes value here to avoid all the virtual function calls. + // subtract 6 as we know we'll be reading a short and an int (for the utf and value lengths). + final int minimumBytesRequired = 6; + int readableBytes = in.readableBytes() - minimumBytesRequired; + if (readableBytes < 0) + return false; + + // this is a tad invasive, but since we know the UTF string is prefaced with a 2-byte length, + // read that to make sure we have enough bytes to read the string itself. + short strLen = in.readShort(); + // check if we can read that many bytes for the UTF + if (strLen > readableBytes) + { + in.resetReaderIndex(); + return false; + } + in.skipBytes(strLen); + readableBytes -= strLen; + + // check if we can read the value length + if (readableBytes < PARAMETERS_VALUE_SIZE_LENGTH) + { + in.resetReaderIndex(); + return false; + } + int valueLength = in.readInt(); + // check if we read that many bytes for the value + if (valueLength > readableBytes) + { + in.resetReaderIndex(); + return false; + } + + in.resetReaderIndex(); + return true; + } + + + @Override + MessageHeader getMessageHeader() + { + return messageHeader; + } +} http://git-wip-us.apache.org/repos/asf/cassandra/blob/06209037/src/java/org/apache/cassandra/utils/vint/VIntCoding.java ---------------------------------------------------------------------- diff --git a/src/java/org/apache/cassandra/utils/vint/VIntCoding.java b/src/java/org/apache/cassandra/utils/vint/VIntCoding.java index a8a1654..67444a9 100644 --- a/src/java/org/apache/cassandra/utils/vint/VIntCoding.java +++ b/src/java/org/apache/cassandra/utils/vint/VIntCoding.java @@ -50,6 +50,7 @@ import java.io.DataInput; import java.io.DataOutput; import java.io.IOException; +import io.netty.buffer.ByteBuf; import io.netty.util.concurrent.FastThreadLocal; import net.nicoulaj.compilecommand.annotations.Inline; @@ -81,6 +82,42 @@ public class VIntCoding return retval; } + /** + * Note this method is the same as {@link #readUnsignedVInt(DataInput)}, + * except that we do *not* block if there are not enough bytes in the buffer + * to reconstruct the value. + */ + public static long readUnsignedVInt(ByteBuf input) + { + if (!input.isReadable()) + return -1; + + input.markReaderIndex(); + int firstByte = input.readByte(); + + //Bail out early if this is one byte, necessary or it fails later + if (firstByte >= 0) + return firstByte; + + int size = numberOfExtraBytesToRead(firstByte); + + if (input.readableBytes() < size) + { + input.resetReaderIndex(); + return -1; + } + + long retval = firstByte & firstByteValueMask(size); + for (int ii = 0; ii < size; ii++) + { + byte b = input.readByte(); + retval <<= 8; + retval |= b & 0xff; + } + + return retval; + } + public static long readVInt(DataInput input) throws IOException { return decodeZigZag64(readUnsignedVInt(input)); http://git-wip-us.apache.org/repos/asf/cassandra/blob/06209037/test/microbench/org/apache/cassandra/test/microbench/MessageOutBench.java ---------------------------------------------------------------------- diff --git a/test/microbench/org/apache/cassandra/test/microbench/MessageOutBench.java b/test/microbench/org/apache/cassandra/test/microbench/MessageOutBench.java new file mode 100644 index 0000000..43b0c16 --- /dev/null +++ b/test/microbench/org/apache/cassandra/test/microbench/MessageOutBench.java @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.test.microbench; + +import java.io.IOException; +import java.util.Collections; +import java.util.EnumMap; +import java.util.Map; +import java.util.UUID; +import java.util.concurrent.TimeUnit; +import java.util.function.BiConsumer; + +import com.google.common.collect.ImmutableList; +import com.google.common.net.InetAddresses; +import com.google.common.primitives.Shorts; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import org.apache.cassandra.config.DatabaseDescriptor; +import org.apache.cassandra.exceptions.RequestFailureReason; +import org.apache.cassandra.locator.InetAddressAndPort; +import org.apache.cassandra.net.MessageIn; +import org.apache.cassandra.net.MessageOut; +import org.apache.cassandra.net.MessagingService; +import org.apache.cassandra.net.ParameterType; +import org.apache.cassandra.net.async.BaseMessageInHandler; +import org.apache.cassandra.net.async.ByteBufDataOutputPlus; +import org.apache.cassandra.net.async.MessageInHandler; +import org.apache.cassandra.net.async.MessageInHandlerPre40; +import org.apache.cassandra.utils.NanoTimeToCurrentTimeMillis; +import org.apache.cassandra.utils.UUIDGen; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.OutputTimeUnit; +import org.openjdk.jmh.annotations.Param; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.Warmup; + +import static org.apache.cassandra.net.async.OutboundConnectionIdentifier.ConnectionType.SMALL_MESSAGE; + +@State(Scope.Thread) +@Warmup(iterations = 4, time = 1, timeUnit = TimeUnit.SECONDS) +@Measurement(iterations = 8, time = 4, timeUnit = TimeUnit.SECONDS) +@Fork(value = 1,jvmArgsAppend = "-Xmx512M") +@OutputTimeUnit(TimeUnit.NANOSECONDS) +@BenchmarkMode(Mode.SampleTime) +public class MessageOutBench +{ + @Param({ "true", "false" }) + private boolean withParams; + + private MessageOut msgOut; + private ByteBuf buf; + BaseMessageInHandler handler40; + BaseMessageInHandler handlerPre40; + + @Setup + public void setup() + { + DatabaseDescriptor.daemonInitialization(); + InetAddressAndPort addr = InetAddressAndPort.getByAddress(InetAddresses.forString("127.0.73.101")); + + UUID uuid = UUIDGen.getTimeUUID(); + Map<ParameterType, Object> parameters = new EnumMap<>(ParameterType.class); + + if (withParams) + { + parameters.put(ParameterType.FAILURE_RESPONSE, MessagingService.ONE_BYTE); + parameters.put(ParameterType.FAILURE_REASON, Shorts.checkedCast(RequestFailureReason.READ_TOO_MANY_TOMBSTONES.code)); + parameters.put(ParameterType.TRACE_SESSION, uuid); + } + + msgOut = new MessageOut<>(addr, MessagingService.Verb.ECHO, null, null, ImmutableList.of(), SMALL_MESSAGE); + buf = Unpooled.buffer(1024, 1024); // 1k should be enough for everybody! + + handler40 = new MessageInHandler(addr, MessagingService.VERSION_40, messageConsumer); + handlerPre40 = new MessageInHandlerPre40(addr, MessagingService.VERSION_30, messageConsumer); + } + + @Benchmark + public int serialize40() throws IOException + { + return serialize(MessagingService.VERSION_40, handler40); + } + + private int serialize(int messagingVersion, BaseMessageInHandler handler) throws IOException + { + buf.resetReaderIndex(); + buf.resetWriterIndex(); + buf.writeInt(MessagingService.PROTOCOL_MAGIC); + buf.writeInt(42); // this is the id + buf.writeInt((int) NanoTimeToCurrentTimeMillis.convert(System.nanoTime())); + + msgOut.serialize(new ByteBufDataOutputPlus(buf), messagingVersion); + handler.decode(null, buf, Collections.emptyList()); + return msgOut.serializedSize(messagingVersion); + } + + @Benchmark + public int serializePre40() throws IOException + { + return serialize(MessagingService.VERSION_30, handlerPre40); + } + + private final BiConsumer<MessageIn, Integer> messageConsumer = (messageIn, integer) -> + { + }; +} http://git-wip-us.apache.org/repos/asf/cassandra/blob/06209037/test/unit/org/apache/cassandra/net/async/HandshakeHandlersTest.java ---------------------------------------------------------------------- diff --git a/test/unit/org/apache/cassandra/net/async/HandshakeHandlersTest.java b/test/unit/org/apache/cassandra/net/async/HandshakeHandlersTest.java index f92ce5a..087f49e 100644 --- a/test/unit/org/apache/cassandra/net/async/HandshakeHandlersTest.java +++ b/test/unit/org/apache/cassandra/net/async/HandshakeHandlersTest.java @@ -19,14 +19,17 @@ package org.apache.cassandra.net.async; import java.io.IOException; -import java.net.InetSocketAddress; import java.nio.ByteBuffer; +import java.util.Arrays; import java.util.Optional; import com.google.common.net.InetAddresses; import org.junit.Assert; import org.junit.BeforeClass; import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.junit.runners.Parameterized.Parameters; import io.netty.channel.embedded.EmbeddedChannel; import org.apache.cassandra.SchemaLoader; @@ -47,6 +50,7 @@ import org.apache.cassandra.schema.KeyspaceParams; import static org.apache.cassandra.net.async.InboundHandshakeHandler.State.HANDSHAKE_COMPLETE; import static org.apache.cassandra.net.async.OutboundMessagingConnection.State.READY; +@RunWith(Parameterized.class) public class HandshakeHandlersTest { private static final String KEYSPACE1 = "NettyPipilineTest"; @@ -54,8 +58,8 @@ public class HandshakeHandlersTest private static final InetAddressAndPort LOCAL_ADDR = InetAddressAndPort.getByAddressOverrideDefaults(InetAddresses.forString("127.0.0.1"), 9999); private static final InetAddressAndPort REMOTE_ADDR = InetAddressAndPort.getByAddressOverrideDefaults(InetAddresses.forString("127.0.0.2"), 9999); - private static final int MESSAGING_VERSION = MessagingService.current_version; private static final OutboundConnectionIdentifier connectionId = OutboundConnectionIdentifier.small(LOCAL_ADDR, REMOTE_ADDR); + private final int messagingVersion; @BeforeClass public static void beforeClass() throws ConfigurationException @@ -67,6 +71,17 @@ public class HandshakeHandlersTest CompactionManager.instance.disableAutoCompaction(); } + public HandshakeHandlersTest(int messagingVersion) + { + this.messagingVersion = messagingVersion; + } + + @Parameters() + public static Iterable<?> generateData() + { + return Arrays.asList(MessagingService.VERSION_30, MessagingService.VERSION_40); + } + @Test public void handshake_HappyPath() { @@ -169,19 +184,20 @@ public class HandshakeHandlersTest .compress(compress) .coalescingStrategy(Optional.empty()) .protocolVersion(MessagingService.current_version) + .backlogSupplier(this::nopBacklog) .build(); OutboundHandshakeHandler outboundHandshakeHandler = new OutboundHandshakeHandler(params); EmbeddedChannel outboundChannel = new EmbeddedChannel(outboundHandshakeHandler); OutboundMessagingConnection omc = new OutboundMessagingConnection(connectionId, null, Optional.empty(), new AllowAllInternodeAuthenticator()); - omc.setTargetVersion(MESSAGING_VERSION); - outboundHandshakeHandler.setupPipeline(outboundChannel, MESSAGING_VERSION); + omc.setTargetVersion(messagingVersion); + outboundHandshakeHandler.setupPipeline(outboundChannel, messagingVersion); // remove the outbound handshake message from the outbound messages outboundChannel.outboundMessages().clear(); InboundHandshakeHandler handler = new InboundHandshakeHandler(new TestAuthenticator(true)); EmbeddedChannel inboundChannel = new EmbeddedChannel(handler); - handler.setupMessagingPipeline(inboundChannel.pipeline(), REMOTE_ADDR, compress, MESSAGING_VERSION); + handler.setupMessagingPipeline(inboundChannel.pipeline(), REMOTE_ADDR, compress, messagingVersion); return new TestChannels(outboundChannel, inboundChannel); } @@ -203,4 +219,9 @@ public class HandshakeHandlersTest // do nothing, really return null; } + + private QueuedMessage nopBacklog() + { + return null; + } } http://git-wip-us.apache.org/repos/asf/cassandra/blob/06209037/test/unit/org/apache/cassandra/net/async/MessageInHandlerTest.java ---------------------------------------------------------------------- diff --git a/test/unit/org/apache/cassandra/net/async/MessageInHandlerTest.java b/test/unit/org/apache/cassandra/net/async/MessageInHandlerTest.java index 43cb964..16f4faf 100644 --- a/test/unit/org/apache/cassandra/net/async/MessageInHandlerTest.java +++ b/test/unit/org/apache/cassandra/net/async/MessageInHandlerTest.java @@ -21,38 +21,47 @@ package org.apache.cassandra.net.async; import java.io.EOFException; import java.io.IOException; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collections; -import java.util.HashMap; +import java.util.EnumMap; import java.util.List; import java.util.Map; import java.util.UUID; import java.util.function.BiConsumer; +import com.google.common.collect.ImmutableList; import com.google.common.net.InetAddresses; +import com.google.common.primitives.Shorts; import org.junit.After; import org.junit.Assert; import org.junit.BeforeClass; import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.junit.runners.Parameterized.Parameters; import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; import io.netty.channel.embedded.EmbeddedChannel; import org.apache.cassandra.config.DatabaseDescriptor; +import org.apache.cassandra.exceptions.RequestFailureReason; import org.apache.cassandra.locator.InetAddressAndPort; import org.apache.cassandra.net.MessageIn; import org.apache.cassandra.net.MessageOut; import org.apache.cassandra.net.MessagingService; import org.apache.cassandra.net.ParameterType; -import org.apache.cassandra.net.async.MessageInHandler.MessageHeader; import org.apache.cassandra.utils.NanoTimeToCurrentTimeMillis; import org.apache.cassandra.utils.UUIDGen; +import static org.apache.cassandra.net.async.OutboundConnectionIdentifier.ConnectionType.SMALL_MESSAGE; + +@RunWith(Parameterized.class) public class MessageInHandlerTest { - private static final InetAddressAndPort addr = InetAddressAndPort.getByAddressOverrideDefaults(InetAddresses.forString("127.0.0.1"), 0); - private static final int MSG_VERSION = MessagingService.current_version; - private static final int MSG_ID = 42; + private static InetAddressAndPort addr; + + private final int messagingVersion; private ByteBuf buf; @@ -60,6 +69,18 @@ public class MessageInHandlerTest public static void before() { DatabaseDescriptor.daemonInitialization(); + addr = InetAddressAndPort.getByAddress(InetAddresses.forString("127.0.73.101")); + } + + public MessageInHandlerTest(int messagingVersion) + { + this.messagingVersion = messagingVersion; + } + + @Parameters() + public static Iterable<?> generateData() + { + return Arrays.asList(MessagingService.VERSION_30, MessagingService.VERSION_40); } @After @@ -69,15 +90,23 @@ public class MessageInHandlerTest buf.release(); } + private BaseMessageInHandler getHandler(InetAddressAndPort addr, int messagingVersion, BiConsumer<MessageIn, Integer> messageConsumer) + { + if (messagingVersion >= MessagingService.VERSION_40) + return new MessageInHandler(addr, messagingVersion, messageConsumer); + return new MessageInHandlerPre40(addr, messagingVersion, messageConsumer); + } + + @Test - public void decode_BadMagic() throws Exception + public void decode_BadMagic() { int len = MessageInHandler.FIRST_SECTION_BYTE_COUNT; buf = Unpooled.buffer(len, len); buf.writeInt(-1); buf.writerIndex(len); - MessageInHandler handler = new MessageInHandler(addr, MSG_VERSION, null); + BaseMessageInHandler handler = getHandler(addr, messagingVersion, null); EmbeddedChannel channel = new EmbeddedChannel(handler); Assert.assertTrue(channel.isOpen()); channel.writeInbound(buf); @@ -95,24 +124,26 @@ public class MessageInHandlerTest public void decode_HappyPath_WithParameters() throws Exception { UUID uuid = UUIDGen.getTimeUUID(); - Map<ParameterType, Object> parameters = new HashMap<>(); - parameters.put(ParameterType.FAILURE_REASON, (short)42); + Map<ParameterType, Object> parameters = new EnumMap<>(ParameterType.class); + parameters.put(ParameterType.FAILURE_RESPONSE, MessagingService.ONE_BYTE); + parameters.put(ParameterType.FAILURE_REASON, Shorts.checkedCast(RequestFailureReason.READ_TOO_MANY_TOMBSTONES.code)); parameters.put(ParameterType.TRACE_SESSION, uuid); MessageInWrapper result = decode_HappyPath(parameters); - Assert.assertEquals(2, result.messageIn.parameters.size()); - Assert.assertEquals((short)42, result.messageIn.parameters.get(ParameterType.FAILURE_REASON)); + Assert.assertEquals(3, result.messageIn.parameters.size()); + Assert.assertTrue(result.messageIn.isFailureResponse()); + Assert.assertEquals(RequestFailureReason.READ_TOO_MANY_TOMBSTONES, result.messageIn.getFailureReason()); Assert.assertEquals(uuid, result.messageIn.parameters.get(ParameterType.TRACE_SESSION)); } private MessageInWrapper decode_HappyPath(Map<ParameterType, Object> parameters) throws Exception { - MessageOut msgOut = new MessageOut(MessagingService.Verb.ECHO); + MessageOut msgOut = new MessageOut<>(addr, MessagingService.Verb.ECHO, null, null, ImmutableList.of(), SMALL_MESSAGE); for (Map.Entry<ParameterType, Object> param : parameters.entrySet()) msgOut = msgOut.withParameter(param.getKey(), param.getValue()); serialize(msgOut); MessageInWrapper wrapper = new MessageInWrapper(); - MessageInHandler handler = new MessageInHandler(addr, MSG_VERSION, wrapper.messageConsumer); + BaseMessageInHandler handler = getHandler(addr, messagingVersion, wrapper.messageConsumer); List<Object> out = new ArrayList<>(); handler.decode(null, buf, out); @@ -132,14 +163,15 @@ public class MessageInHandlerTest buf.writeInt(MSG_ID); // this is the id buf.writeInt((int) NanoTimeToCurrentTimeMillis.convert(System.nanoTime())); - msgOut.serialize(new ByteBufDataOutputPlus(buf), MSG_VERSION); + msgOut.serialize(new ByteBufDataOutputPlus(buf), messagingVersion); } @Test public void decode_WithHalfReceivedParameters() throws Exception { - MessageOut msgOut = new MessageOut(MessagingService.Verb.ECHO); - msgOut = msgOut.withParameter(ParameterType.FAILURE_REASON, (short)42); + MessageOut msgOut = new MessageOut<>(addr, MessagingService.Verb.ECHO, null, null, ImmutableList.of(), SMALL_MESSAGE); + UUID uuid = UUIDGen.getTimeUUID(); + msgOut = msgOut.withParameter(ParameterType.TRACE_SESSION, uuid); serialize(msgOut); @@ -149,13 +181,13 @@ public class MessageInHandlerTest buf.writerIndex(originalWriterIndex - 6); MessageInWrapper wrapper = new MessageInWrapper(); - MessageInHandler handler = new MessageInHandler(addr, MSG_VERSION, wrapper.messageConsumer); + BaseMessageInHandler handler = getHandler(addr, messagingVersion, wrapper.messageConsumer); List<Object> out = new ArrayList<>(); handler.decode(null, buf, out); Assert.assertNull(wrapper.messageIn); - MessageHeader header = handler.getMessageHeader(); + BaseMessageInHandler.MessageHeader header = handler.getMessageHeader(); Assert.assertEquals(MSG_ID, header.messageId); Assert.assertEquals(msgOut.verb, header.verb); Assert.assertEquals(msgOut.from, header.from); @@ -171,56 +203,59 @@ public class MessageInHandlerTest @Test public void canReadNextParam_HappyPath() throws IOException { - buildParamBuf(13); - Assert.assertTrue(MessageInHandler.canReadNextParam(buf)); + buildParamBufPre40(13); + Assert.assertTrue(MessageInHandlerPre40.canReadNextParam(buf)); } @Test public void canReadNextParam_OnlyFirstByte() throws IOException { - buildParamBuf(13); + buildParamBufPre40(13); buf.writerIndex(1); - Assert.assertFalse(MessageInHandler.canReadNextParam(buf)); + Assert.assertFalse(MessageInHandlerPre40.canReadNextParam(buf)); } @Test public void canReadNextParam_PartialUTF() throws IOException { - buildParamBuf(13); + buildParamBufPre40(13); buf.writerIndex(5); - Assert.assertFalse(MessageInHandler.canReadNextParam(buf)); + Assert.assertFalse(MessageInHandlerPre40.canReadNextParam(buf)); } @Test public void canReadNextParam_TruncatedValueLength() throws IOException { - buildParamBuf(13); + buildParamBufPre40(13); buf.writerIndex(buf.writerIndex() - 13 - 2); - Assert.assertFalse(MessageInHandler.canReadNextParam(buf)); + Assert.assertFalse(MessageInHandlerPre40.canReadNextParam(buf)); } @Test public void canReadNextParam_MissingLastBytes() throws IOException { - buildParamBuf(13); + buildParamBufPre40(13); buf.writerIndex(buf.writerIndex() - 2); - Assert.assertFalse(MessageInHandler.canReadNextParam(buf)); + Assert.assertFalse(MessageInHandlerPre40.canReadNextParam(buf)); } - private void buildParamBuf(int valueLength) throws IOException + private void buildParamBufPre40(int valueLength) throws IOException { buf = Unpooled.buffer(1024, 1024); // 1k should be enough for everybody! - ByteBufDataOutputPlus output = new ByteBufDataOutputPlus(buf); - output.writeUTF("name"); - byte[] array = new byte[valueLength]; - output.writeInt(array.length); - output.write(array); + + try (ByteBufDataOutputPlus output = new ByteBufDataOutputPlus(buf)) + { + output.writeUTF("name"); + byte[] array = new byte[valueLength]; + output.writeInt(array.length); + output.write(array); + } } @Test public void exceptionHandled() { - MessageInHandler handler = new MessageInHandler(addr, MSG_VERSION, null); + BaseMessageInHandler handler = getHandler(addr, messagingVersion, null); EmbeddedChannel channel = new EmbeddedChannel(handler); Assert.assertTrue(channel.isOpen()); handler.exceptionCaught(channel.pipeline().firstContext(), new EOFException()); http://git-wip-us.apache.org/repos/asf/cassandra/blob/06209037/test/unit/org/apache/cassandra/utils/vint/VIntCodingTest.java ---------------------------------------------------------------------- diff --git a/test/unit/org/apache/cassandra/utils/vint/VIntCodingTest.java b/test/unit/org/apache/cassandra/utils/vint/VIntCodingTest.java index 1db8f9d..2189be3 100644 --- a/test/unit/org/apache/cassandra/utils/vint/VIntCodingTest.java +++ b/test/unit/org/apache/cassandra/utils/vint/VIntCodingTest.java @@ -20,8 +20,13 @@ package org.apache.cassandra.utils.vint; import java.io.ByteArrayOutputStream; import java.io.DataOutputStream; +import java.io.IOException; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; import org.apache.cassandra.io.util.DataOutputBuffer; +import org.apache.cassandra.net.async.ByteBufDataOutputPlus; + import org.junit.Test; import org.junit.Assert; @@ -82,4 +87,14 @@ public class VIntCodingTest Assert.assertEquals( 1, dob.buffer().remaining()); dob.close(); } + + @Test + public void testByteBufWithNegativeNumber() throws IOException + { + int i = -1231238694; + ByteBuf buf = Unpooled.buffer(8); + VIntCoding.writeUnsignedVInt(i, new ByteBufDataOutputPlus(buf)); + long result = VIntCoding.readUnsignedVInt(buf); + Assert.assertEquals(i, result); + } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@cassandra.apache.org For additional commands, e-mail: commits-h...@cassandra.apache.org