http://git-wip-us.apache.org/repos/asf/cassandra/blob/356dc3c2/src/java/org/apache/cassandra/net/async/ChannelWriter.java ---------------------------------------------------------------------- diff --git a/src/java/org/apache/cassandra/net/async/ChannelWriter.java b/src/java/org/apache/cassandra/net/async/ChannelWriter.java new file mode 100644 index 0000000..e984736 --- /dev/null +++ b/src/java/org/apache/cassandra/net/async/ChannelWriter.java @@ -0,0 +1,418 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.cassandra.net.async; + +import java.io.IOException; +import java.util.Optional; +import java.util.Queue; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicLong; +import java.util.function.Consumer; + +import com.google.common.annotations.VisibleForTesting; + +import io.netty.buffer.Unpooled; +import io.netty.channel.Channel; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelOutboundBuffer; +import io.netty.channel.ChannelPromise; +import io.netty.channel.MessageSizeEstimator; +import io.netty.handler.timeout.IdleStateEvent; +import io.netty.util.Attribute; +import io.netty.util.AttributeKey; +import io.netty.util.concurrent.Future; +import org.apache.cassandra.config.DatabaseDescriptor; +import org.apache.cassandra.utils.CoalescingStrategies; +import org.apache.cassandra.utils.CoalescingStrategies.CoalescingStrategy; + +/** + * Represents a ready and post-handshake channel that can send outbound messages. This class groups a netty channel + * with any other channel-related information we track and, most importantly, handles the details on when the channel is flushed. + * + * <h2>Flushing</h2> + * + * We don't flush to the socket on every message as it's a bit of a performance drag (making the system call, copying + * the buffer, sending out a small packet). Thus, by waiting until we have a decent chunk of data (for some definition + * of 'decent'), we can achieve better efficiency and improved performance (yay!). + * <p> + * When to flush mainly depends on whether we use message coalescing or not (see {@link CoalescingStrategies}). + * <p> + * Note that the callback functions are invoked on the netty event loop, which is (in almost all cases) different + * from the thread that will be invoking {@link #write(QueuedMessage, boolean)}. 
+ * + * <h3>Flushing without coalescing</h3> + * + * When no coalescing is in effect, we want to send new messages "right away". However, as said above, flushing after + * every message would be particularly inefficient when there are lots of messages in our sending queue, and so in + * practice we want to flush in 2 cases: + * 1) After any message <b>if</b> there is no pending message in the send queue. + * 2) When we've filled up or exceeded the netty outbound buffer (see {@link ChannelOutboundBuffer}) + * <p> + * The second part is relatively simple and handled generically in {@link MessageOutHandler#write(ChannelHandlerContext, Object, ChannelPromise)} [1]. + * The first part however is made a little more complicated by how netty's event loop executes. It is woken up by + * external callers to the channel invoking a flush, via either {@link Channel#flush} or one of the {@link Channel#writeAndFlush} + * methods [2]. So a plain {@link Channel#write} will only queue the message in the channel, and not wake up the event loop. + * <p> + * This means we don't want to simply call {@link Channel#write} as we want the message processed immediately. But we + * also don't want to flush on every message if there is more in the sending queue, so simply calling + * {@link Channel#writeAndFlush} isn't completely appropriate either. In practice, we handle this by calling + * {@link Channel#writeAndFlush} (so the netty event loop <b>does</b> wake up), but we override the flush behavior so + * it actually only flushes if there are no pending messages (see how {@link MessageOutHandler#flush} delegates the flushing + * decision back to this class through {@link #onTriggeredFlush}, and how {@link SimpleChannelWriter} makes this a no-op; + * instead {@link SimpleChannelWriter} flushes after any message if there are no more pending ones in + * {@link #onMessageProcessed}). 
+ * + * <h3>Flushing with coalescing</h3> + * + * The goal of coalescing is to (artificially) delay the flushing of data in order to aggregate even more data before + * sending a group of packets out. So we don't want to flush after messages even if there are no pending messages in the + * sending queue, but we rather want to delegate the decision on when to flush to the {@link CoalescingStrategy}. In + * practice, when coalescing is enabled we will flush in 2 cases: + * 1) When the coalescing strategy decides that we should. + * 2) When we've filled up or exceeded the netty outbound buffer ({@link ChannelOutboundBuffer}), exactly like in the + * no coalescing case. + * <p> + * The second part is handled exactly like in the no coalescing case, see above. + * The first part is handled by {@link CoalescingChannelWriter#write(QueuedMessage, boolean)}. Whenever a message is sent, we check + * if a flush has been already scheduled by the coalescing strategy. If one has, we're done, otherwise we ask the + * strategy when the next flush should happen and schedule one. + * + * <h2>Message timeouts and retries</h2> + * + * The main outward-facing method is {@link #write(QueuedMessage, boolean)}, where callers pass a + * {@link QueuedMessage}. If a message times out, as defined in {@link QueuedMessage#isTimedOut()}, + * the message listener {@link #handleMessageFuture(Future, QueuedMessage, boolean)} is invoked + * with the cause being an {@link ExpiredException}. The message is not retried and it is dropped on the floor. + * <p> + * If there is some {@link IOException} on the socket after the message has been written to the netty channel, + * the message listener {@link #handleMessageFuture(Future, QueuedMessage, boolean)} is invoked + * and 1) we check to see if the connection should be re-established, and 2) possibly retry the message. 
+ * + * <h2>Failures</h2> + * + * <h3>Failure to make progress sending bytes</h3> + * If we are unable to make progress sending messages, we'll receive a netty notification + * ({@link IdleStateEvent}) at {@link MessageOutHandler#userEventTriggered(ChannelHandlerContext, Object)}. + * We then want to close the socket/channel, and purge any messages in {@link OutboundMessagingConnection#backlog} + * to try to free up memory as quickly as possible. Any messages in the netty pipeline will be marked as fail + * (as we close the channel), but {@link MessageOutHandler#userEventTriggered(ChannelHandlerContext, Object)} also + * sets a channel attribute, {@link #PURGE_MESSAGES_CHANNEL_ATTR} to true. This is essentially as special flag + * that we can look at in the promise handler code ({@link #handleMessageFuture(Future, QueuedMessage, boolean)}) + * to indicate that any backlog should be thrown away. + * + * <h2>Notes</h2> + * [1] For those desperately interested, and only after you've read the entire class-level doc: You can register a custom + * {@link MessageSizeEstimator} with a netty channel. When a message is written to the channel, it will check the + * message size, and if the max ({@link ChannelOutboundBuffer}) size will be exceeded, a task to signal the "channel + * writability changed" will be executed in the channel. That task, however, will wake up the event loop. + * Thus if coalescing is enabled, the event loop will wake up prematurely and process (and possibly flush!) the messages + * currently in the queue, thus defeating an aspect of coalescing. Hence, we're not using that feature of netty. + * [2]: The netty event loop is also woken up by it's internal timeout on the epoll_wait() system call. + */ +abstract class ChannelWriter +{ + /** + * A netty channel {@link Attribute} to indicate, when a channel is closed, any backlogged messages should be purged, + * as well. See the class-level documentation for more information. 
+ */ + static final AttributeKey<Boolean> PURGE_MESSAGES_CHANNEL_ATTR = AttributeKey.newInstance("purgeMessages"); + + protected final Channel channel; + private volatile boolean closed; + + /** Number of currently pending messages on this channel. */ + final AtomicLong pendingMessageCount = new AtomicLong(0); + + /** + * A consuming function that handles the result of each message sent. + */ + private final Consumer<MessageResult> messageResultConsumer; + + /** + * A reusable instance to avoid creating garbage on processing the result of every message sent. + * As we have the guarantee that the netty event loop is single threaded, there should be no contention over this + * instance, as long as it (not its state) is shared across threads. + */ + private final MessageResult messageResult = new MessageResult(); + + protected ChannelWriter(Channel channel, Consumer<MessageResult> messageResultConsumer) + { + this.channel = channel; + this.messageResultConsumer = messageResultConsumer; + channel.attr(PURGE_MESSAGES_CHANNEL_ATTR).set(false); + } + + /** + * Creates a new {@link ChannelWriter} using the (assumed properly connected) provided channel, and using coalescing + * based on the provided strategy. + */ + static ChannelWriter create(Channel channel, Consumer<MessageResult> messageResultConsumer, Optional<CoalescingStrategy> coalescingStrategy) + { + return coalescingStrategy.isPresent() + ? new CoalescingChannelWriter(channel, messageResultConsumer, coalescingStrategy.get()) + : new SimpleChannelWriter(channel, messageResultConsumer); + } + + /** + * Writes a message to this {@link ChannelWriter} if the channel is writable. + * <p> + * We always want to write to the channel *unless* it's not writable yet still open. + * If the channel is closed, the promise will be notified as failed (due to channel closed), + * and let the handler ({@link #handleMessageFuture(Future, QueuedMessage, boolean)}) + * do the reconnect magic/dance. 
Thus we simplify when to reconnect by not burdening the (concurrent) callers + * of this method, and instead keep it all in the future handler/event loop (which is single threaded). + * <p> + * Concretely, the message is handed to the channel when {@code checkWritability} is false, when the channel is + * writable, or when the channel is no longer open; otherwise nothing is written and false is returned. + * + * @param message the message to write/send. + * @param checkWritability a flag to indicate if the status of the channel should be checked before passing + * the message on to the {@link #channel}. + * @return true if the message was written to the channel; else, false. + */ + boolean write(QueuedMessage message, boolean checkWritability) + { + if ( (checkWritability && (channel.isWritable()) || !channel.isOpen()) || !checkWritability) + { + write0(message).addListener(f -> handleMessageFuture(f, message, true)); + return true; + } + return false; + } + + /** + * Handles the future of sending a particular message on this {@link ChannelWriter}. + * <p> + * Note: this is called from the netty event loop, so there is no race across multiple executions of this method. + */ + @VisibleForTesting + void handleMessageFuture(Future<? super Void> future, QueuedMessage msg, boolean allowReconnect) + { + messageResult.setAll(this, msg, future, allowReconnect); + messageResultConsumer.accept(messageResult); + messageResult.clearAll(); + } + + boolean shouldPurgeBacklog() + { + if (!channel.attr(PURGE_MESSAGES_CHANNEL_ATTR).get()) + return false; + + channel.attr(PURGE_MESSAGES_CHANNEL_ATTR).set(false); + return true; + } + + /** + * Writes a backlog of messages to this {@link ChannelWriter}. This is mostly equivalent to calling + * {@link #write(QueuedMessage, boolean)} for every message of the provided backlog queue, but + * it ignores any coalescing, triggering a flush only once after all messages have been sent. + * Messages are polled from the queue only while the channel reports itself writable. + * + * @param backlog the backlog of messages to send. + * @param allowReconnect passed on to {@link #handleMessageFuture(Future, QueuedMessage, boolean)} for every message written. + * @return the count of items written to the channel from the queue. 
+ */ + int writeBacklog(Queue<QueuedMessage> backlog, boolean allowReconnect) + { + int count = 0; + while (true) + { + if (!channel.isWritable()) + break; + + QueuedMessage msg = backlog.poll(); + if (msg == null) + break; + + pendingMessageCount.incrementAndGet(); + ChannelFuture future = channel.write(msg); + future.addListener(f -> handleMessageFuture(f, msg, allowReconnect)); + count++; + } + + // as this is an infrequent operation, don't bother coordinating with the instance-level flush task + if (count > 0) + channel.flush(); + + return count; + } + + void close() + { + if (closed) + return; + + closed = true; + channel.close(); + } + + long pendingMessageCount() + { + return pendingMessageCount.get(); + } + + /** + * Close the underlying channel but only after having made sure every pending message has been properly sent. + */ + void softClose() + { + if (closed) + return; + + closed = true; + channel.writeAndFlush(Unpooled.EMPTY_BUFFER).addListener(ChannelFutureListener.CLOSE); + } + + @VisibleForTesting + boolean isClosed() + { + return closed; + } + + /** + * Write the message to the {@link #channel}. + * <p> + * Note: this method, in almost all cases, is invoked from an app-level writing thread, not the netty event loop. + */ + protected abstract ChannelFuture write0(QueuedMessage message); + + /** + * Invoked after a message has been processed in the pipeline. Should only be used for essential bookkeeping operations. + * <p> + * Note: this method is invoked on the netty event loop. + */ + abstract void onMessageProcessed(ChannelHandlerContext ctx); + + /** + * Invoked when pipeline receives a flush request. + * <p> + * Note: this method is invoked on the netty event loop. + */ + abstract void onTriggeredFlush(ChannelHandlerContext ctx); + + /** + * Handles the non-coalescing flush case. 
+ */ + @VisibleForTesting + static class SimpleChannelWriter extends ChannelWriter + { + private SimpleChannelWriter(Channel channel, Consumer<MessageResult> messageResultConsumer) + { + super(channel, messageResultConsumer); + } + + protected ChannelFuture write0(QueuedMessage message) + { + pendingMessageCount.incrementAndGet(); + // We don't truly want to flush on every message but we do want to wake-up the netty event loop for the + // channel so the message is processed right away, which is why we use writeAndFlush. This won't actually + // flush, though, because onTriggeredFlush, which MessageOutHandler delegates to, does nothing. We will + // flush after the message is processed though if there is no pending one due to onMessageProcessed. + // See the class javadoc for context and much more details. + return channel.writeAndFlush(message); + } + + void onMessageProcessed(ChannelHandlerContext ctx) + { + if (pendingMessageCount.decrementAndGet() == 0) + ctx.flush(); + } + + void onTriggeredFlush(ChannelHandlerContext ctx) + { + // Don't actually flush on "normal" flush calls to the channel. + } + } + + /** + * Handles the coalescing flush case. 
+ */ + @VisibleForTesting + static class CoalescingChannelWriter extends ChannelWriter + { + private static final int MIN_MESSAGES_FOR_COALESCE = DatabaseDescriptor.getOtcCoalescingEnoughCoalescedMessages(); + + private final CoalescingStrategy strategy; + private final int minMessagesForCoalesce; + + @VisibleForTesting + final AtomicBoolean scheduledFlush = new AtomicBoolean(false); + + CoalescingChannelWriter(Channel channel, Consumer<MessageResult> messageResultConsumer, CoalescingStrategy strategy) + { + this (channel, messageResultConsumer, strategy, MIN_MESSAGES_FOR_COALESCE); + } + + @VisibleForTesting + CoalescingChannelWriter(Channel channel, Consumer<MessageResult> messageResultConsumer, CoalescingStrategy strategy, int minMessagesForCoalesce) + { + super(channel, messageResultConsumer); + this.strategy = strategy; + this.minMessagesForCoalesce = minMessagesForCoalesce; + } + + protected ChannelFuture write0(QueuedMessage message) + { + long pendingCount = pendingMessageCount.incrementAndGet(); + ChannelFuture future = channel.write(message); + strategy.newArrival(message); + + // if we lost the race to set the state, simply write to the channel (no flush) + if (!scheduledFlush.compareAndSet(false, true)) + return future; + + long flushDelayNanos; + // if we've hit the minimum number of messages for coalescing or we've run out of coalesce time, flush. + // note: we check the exact count, instead of greater than or equal to, of message here to prevent a flush task + // for each message (if there's messages coming in on multiple threads). There will be, of course, races + // with the consumer decrementing the pending counter, but that's still less excessive flushes. 
+ if (pendingCount == minMessagesForCoalesce || (flushDelayNanos = strategy.currentCoalescingTimeNanos()) <= 0) + { + scheduledFlush.set(false); + channel.flush(); + } + else + { + // calling schedule() on the eventLoop will force it to wake up (if not already executing) and schedule the task + channel.eventLoop().schedule(() -> { + // NOTE: this executes on the event loop + scheduledFlush.set(false); + // we execute() the flush() as an additional task rather than immediately in-line as there is a + // race condition when this task runs (executing on the event loop) and a thread that writes the channel (top of this method). + // If this task is picked up but before the scheduledFlush flag is flipped, the other thread writes + // and then checks the scheduledFlush (which is still true) and exits. + // This task changes the flag and if it calls flush() in-line, and netty flushes everything immediately (that is, what's been serialized) + // to the transport as we're on the event loop. The other thread's write became a task that executes *after* this task in the netty queue, + // and if there's not a subsequent followup flush scheduled, that write can be orphaned until another write comes in. + channel.eventLoop().execute(channel::flush); + }, flushDelayNanos, TimeUnit.NANOSECONDS); + } + return future; + } + + void onMessageProcessed(ChannelHandlerContext ctx) + { + pendingMessageCount.decrementAndGet(); + } + + void onTriggeredFlush(ChannelHandlerContext ctx) + { + // When coalescing, obey the flush calls normally + ctx.flush(); + } + } +}
http://git-wip-us.apache.org/repos/asf/cassandra/blob/356dc3c2/src/java/org/apache/cassandra/net/async/ExpiredException.java ---------------------------------------------------------------------- diff --git a/src/java/org/apache/cassandra/net/async/ExpiredException.java b/src/java/org/apache/cassandra/net/async/ExpiredException.java new file mode 100644 index 0000000..191900c --- /dev/null +++ b/src/java/org/apache/cassandra/net/async/ExpiredException.java @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.net.async; + +/** + * Thrown when a {@link QueuedMessage} has timed out (has sat in the netty outbound channel for too long). + * <p> + * A single shared {@link #INSTANCE} is declared as a static constant. NOTE(review): because that instance is + * constructed once, during class initialization, its stack trace cannot identify where a message actually expired; + * consider constructing it with a non-writable stack trace and suppression disabled. 
+ */ +class ExpiredException extends Exception +{ + @SuppressWarnings("ThrowableInstanceNeverThrown") + static final ExpiredException INSTANCE = new ExpiredException(); +} http://git-wip-us.apache.org/repos/asf/cassandra/blob/356dc3c2/src/java/org/apache/cassandra/net/async/HandshakeProtocol.java ---------------------------------------------------------------------- diff --git a/src/java/org/apache/cassandra/net/async/HandshakeProtocol.java b/src/java/org/apache/cassandra/net/async/HandshakeProtocol.java new file mode 100644 index 0000000..9b8df80 --- /dev/null +++ b/src/java/org/apache/cassandra/net/async/HandshakeProtocol.java @@ -0,0 +1,304 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.cassandra.net.async; + +import java.io.DataInput; +import java.io.DataOutput; +import java.io.IOException; +import java.net.InetAddress; +import java.util.Objects; + +import com.google.common.annotations.VisibleForTesting; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.ByteBufInputStream; +import io.netty.buffer.ByteBufOutputStream; +import org.apache.cassandra.net.CompactEndpointSerializationHelper; +import org.apache.cassandra.net.MessagingService; + +/** + * Messages for the handshake phase of the internode protocol. + * <p> + * The handshake's main purpose is to establish a protocol version that both sides can talk, as well as exchanging a few connection + * options/parameters. The handshake is composed of 3 messages, the first being sent by the initiator of the connection. The other + * side then answers with the 2nd message. At that point, if a version mismatch is detected by the connection initiator, + * it will simply disconnect and reconnect with a more appropriate version. But if the version is acceptable, the connection + * initiator sends the third message of the protocol, after which it considers the connection ready. + * <p> + * See below for a more precise description of each of those 3 messages. + * <p> + * Note that this handshake protocol doesn't fully apply to streaming. For streaming, only the first message is sent, + * after which the streaming protocol takes over (not documented here) + */ +public class HandshakeProtocol +{ + /** + * The initial message sent when a node creates a new connection to a remote peer. This message contains: + * 1) the {@link MessagingService#PROTOCOL_MAGIC} number (4 bytes). 
+ * 2) the connection flags (4 bytes), which encode: + * - the version the initiator thinks should be used for the connection (in practice, either the initiator + * version if it's the first time we connect to that remote since startup, or the last version known for that + * peer otherwise). + * - the "mode" of the connection: whether it is for streaming or for messaging. + * - whether compression should be used or not (if it is, compression is enabled _after_ the last message of the + * handshake has been sent). + * <p> + * More precisely, connection flags: + * <pre> + * {@code + *                      1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 3 3 + *  0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * |U U C M        |               |                               | + * |N N M O        |    VERSION    |            unused             | + * |U U P D        |               |                               | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * } + * </pre> + * UNU - unused lowest two bits; historical note: used to be "serializer type," which was always Binary + * CMP - compression enabled bit + * MOD - connection mode. If the bit is on, the connection is for streaming; if the bit is off, it is for inter-node messaging. + * VERSION - if a streaming connection, indicates the streaming protocol version {@link org.apache.cassandra.streaming.messages.StreamMessage#CURRENT_VERSION}; + * if a messaging connection, indicates the messaging protocol version the initiator *thinks* should be used. + * <p> + * NOTE(review): encodeFlags() shifts the version left by 8 without masking it, while the layout above gives VERSION + * only 8 bits; presumably versions always fit in 8 bits - confirm, as a larger value would spill into the unused bits. + */ + public static class FirstHandshakeMessage + { + /** Contains the PROTOCOL_MAGIC (int) and the flags (int). 
*/ + private static final int LENGTH = 8; + + final int messagingVersion; + final NettyFactory.Mode mode; + final boolean compressionEnabled; + + public FirstHandshakeMessage(int messagingVersion, NettyFactory.Mode mode, boolean compressionEnabled) + { + assert messagingVersion > 0; + this.messagingVersion = messagingVersion; + this.mode = mode; + this.compressionEnabled = compressionEnabled; + } + + @VisibleForTesting + int encodeFlags() + { + int flags = 0; + if (compressionEnabled) + flags |= 1 << 2; + if (mode == NettyFactory.Mode.STREAMING) + flags |= 1 << 3; + + flags |= (messagingVersion << 8); + return flags; + } + + public ByteBuf encode(ByteBufAllocator allocator) + { + ByteBuf buffer = allocator.directBuffer(LENGTH, LENGTH); + buffer.writerIndex(0); + buffer.writeInt(MessagingService.PROTOCOL_MAGIC); + buffer.writeInt(encodeFlags()); + return buffer; + } + + static FirstHandshakeMessage maybeDecode(ByteBuf in) throws IOException + { + if (in.readableBytes() < LENGTH) + return null; + + MessagingService.validateMagic(in.readInt()); + int flags = in.readInt(); + int version = MessagingService.getBits(flags, 15, 8); + NettyFactory.Mode mode = MessagingService.getBits(flags, 3, 1) == 1 + ? 
NettyFactory.Mode.STREAMING + : NettyFactory.Mode.MESSAGING; + boolean compressed = MessagingService.getBits(flags, 2, 1) == 1; + return new FirstHandshakeMessage(version, mode, compressed); + } + + @Override + public boolean equals(Object other) + { + if (!(other instanceof FirstHandshakeMessage)) + return false; + + FirstHandshakeMessage that = (FirstHandshakeMessage)other; + return this.messagingVersion == that.messagingVersion + && this.mode == that.mode + && this.compressionEnabled == that.compressionEnabled; + } + + @Override + public int hashCode() + { + return Objects.hash(messagingVersion, mode, compressionEnabled); + } + + @Override + public String toString() + { + return String.format("FirstHandshakeMessage - messaging version: %d, mode: %s, compress: %b", messagingVersion, mode, compressionEnabled); + } + } + + /** + * The second message of the handshake, sent by the node receiving the {@link FirstHandshakeMessage} back to the + * connection initiator. This message contains the messaging version of the peer sending this message, + * so {@link org.apache.cassandra.net.MessagingService#current_version}. + */ + static class SecondHandshakeMessage + { + /** The messaging version sent by the receiving peer (int). */ + private static final int LENGTH = 4; + + final int messagingVersion; + + SecondHandshakeMessage(int messagingVersion) + { + this.messagingVersion = messagingVersion; + } + + public ByteBuf encode(ByteBufAllocator allocator) + { + ByteBuf buffer = allocator.directBuffer(LENGTH, LENGTH); + buffer.writerIndex(0); + buffer.writeInt(messagingVersion); + return buffer; + } + + static SecondHandshakeMessage maybeDecode(ByteBuf in) + { + return in.readableBytes() >= LENGTH ? 
new SecondHandshakeMessage(in.readInt()) : null; + } + + @Override + public boolean equals(Object other) + { + return other instanceof SecondHandshakeMessage + && this.messagingVersion == ((SecondHandshakeMessage) other).messagingVersion; + } + + @Override + public int hashCode() + { + return Integer.hashCode(messagingVersion); + } + + @Override + public String toString() + { + return String.format("SecondHandshakeMessage - messaging version: %d", messagingVersion); + } + } + + /** + * The third message of the handshake, sent by the connection initiator on reception of {@link SecondHandshakeMessage}. + * This message contains: + * 1) the connection initiator's messaging version (4 bytes) - {@link org.apache.cassandra.net.MessagingService#current_version}. + * 2) the connection initiator's broadcast address as encoded by {@link org.apache.cassandra.net.CompactEndpointSerializationHelper}. + * This can be either 5 bytes for an IPv4 address, or 17 bytes for an IPv6 one. + * <p> + * This message concludes the handshake protocol. After that, the connection will be used either for streaming, or to + * send messages. If the connection is to be compressed, compression is enabled only after this message is sent/received. + */ + static class ThirdHandshakeMessage + { + /** + * The third message contains the version and IP address of the sending node. Because the IP can be either IPv4 or + * IPv6, this can be either 9 (4 for version + 5 for IP) or 21 (4 for version + 17 for IP) bytes. Since we can't know + * a priori if the IP address will be v4 or v6, go with the minimum required bytes and hope that if the address is + * v6, we'll have the extra 12 bytes in the packet. 
+ */ + private static final int MIN_LENGTH = 9; + + final int messagingVersion; + final InetAddress address; + + ThirdHandshakeMessage(int messagingVersion, InetAddress address) + { + this.messagingVersion = messagingVersion; + this.address = address; + } + + @SuppressWarnings("resource") + public ByteBuf encode(ByteBufAllocator allocator) + { + int bufLength = Integer.BYTES + CompactEndpointSerializationHelper.serializedSize(address); + ByteBuf buffer = allocator.directBuffer(bufLength, bufLength); + buffer.writerIndex(0); + buffer.writeInt(messagingVersion); + try + { + DataOutput bbos = new ByteBufOutputStream(buffer); + CompactEndpointSerializationHelper.serialize(address, bbos); + return buffer; + } + catch (IOException e) + { + // Shouldn't happen, we're serializing in memory. + throw new AssertionError(e); + } + } + + @SuppressWarnings("resource") + static ThirdHandshakeMessage maybeDecode(ByteBuf in) + { + if (in.readableBytes() < MIN_LENGTH) + return null; + + in.markReaderIndex(); + int version = in.readInt(); + DataInput inputStream = new ByteBufInputStream(in); + try + { + InetAddress address = CompactEndpointSerializationHelper.deserialize(inputStream); + return new ThirdHandshakeMessage(version, address); + } + catch (IOException e) + { + // makes the assumption we didn't have enough bytes to deserialize an IPv6 address, + // as we only check the MIN_LENGTH of the buf. 
+ in.resetReaderIndex(); + return null; + } + } + + @Override + public boolean equals(Object other) + { + if (!(other instanceof ThirdHandshakeMessage)) + return false; + + ThirdHandshakeMessage that = (ThirdHandshakeMessage)other; + return this.messagingVersion == that.messagingVersion + && Objects.equals(this.address, that.address); + } + + @Override + public int hashCode() + { + return Objects.hash(messagingVersion, address); + } + + @Override + public String toString() + { + return String.format("ThirdHandshakeMessage - messaging version: %d, address = %s", messagingVersion, address); + } + } +} http://git-wip-us.apache.org/repos/asf/cassandra/blob/356dc3c2/src/java/org/apache/cassandra/net/async/InboundHandshakeHandler.java ---------------------------------------------------------------------- diff --git a/src/java/org/apache/cassandra/net/async/InboundHandshakeHandler.java b/src/java/org/apache/cassandra/net/async/InboundHandshakeHandler.java new file mode 100644 index 0000000..5ea03dc --- /dev/null +++ b/src/java/org/apache/cassandra/net/async/InboundHandshakeHandler.java @@ -0,0 +1,293 @@ +package org.apache.cassandra.net.async; + +import java.io.IOException; +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.net.SocketAddress; +import java.util.List; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import javax.net.ssl.SSLSession; + +import com.google.common.annotations.VisibleForTesting; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelPipeline; +import io.netty.handler.codec.ByteToMessageDecoder; +import io.netty.handler.ssl.SslHandler; +import org.apache.cassandra.auth.IInternodeAuthenticator; +import org.apache.cassandra.config.DatabaseDescriptor; +import org.apache.cassandra.net.MessagingService; +import 
org.apache.cassandra.net.async.HandshakeProtocol.FirstHandshakeMessage; +import org.apache.cassandra.net.async.HandshakeProtocol.SecondHandshakeMessage; +import org.apache.cassandra.net.async.HandshakeProtocol.ThirdHandshakeMessage; + +/** + * 'Server'-side component that negotiates the internode handshake when establishing a new connection. + * This handler will be the first in the netty channel for each incoming connection (secure socket (TLS) notwithstanding), + * and once the handshake is successful, it will configure the proper handlers (mostly {@link MessageInHandler}) + * and remove itself from the working pipeline. + */ +class InboundHandshakeHandler extends ByteToMessageDecoder +{ + private static final Logger logger = LoggerFactory.getLogger(NettyFactory.class); + + enum State { START, AWAITING_HANDSHAKE_BEGIN, AWAIT_STREAM_START_RESPONSE, AWAIT_MESSAGING_START_RESPONSE, MESSAGING_HANDSHAKE_COMPLETE, HANDSHAKE_FAIL } + + private State state; + + private final IInternodeAuthenticator authenticator; + private boolean hasAuthenticated; + + /** + * The peer's declared messaging version. + */ + private int version; + + /** + * Does the peer support (or want to use) compressed data? + */ + private boolean compressed; + + /** + * A future the essentially places a timeout on how long we'll wait for the peer + * to complete the next step of the handshake. 
+ */ + private Future<?> handshakeTimeout; + + InboundHandshakeHandler(IInternodeAuthenticator authenticator) + { + this.authenticator = authenticator; + state = State.START; + } + + @Override + protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) + { + try + { + if (!hasAuthenticated) + { + logSecureSocketDetails(ctx); + if (!handleAuthenticate(ctx.channel().remoteAddress(), ctx)) + return; + } + + switch (state) + { + case START: + state = handleStart(ctx, in); + break; + case AWAIT_MESSAGING_START_RESPONSE: + state = handleMessagingStartResponse(ctx, in); + break; + case HANDSHAKE_FAIL: + throw new IllegalStateException("channel should be closed after determining the handshake failed with peer: " + ctx.channel().remoteAddress()); + default: + logger.error("unhandled state: " + state); + state = State.HANDSHAKE_FAIL; + ctx.close(); + } + } + catch (Exception e) + { + logger.error("unexpected error while negotiating internode messaging handshake", e); + state = State.HANDSHAKE_FAIL; + ctx.close(); + } + } + + /** + * Ensure the peer is allowed to connect to this node. + */ + @VisibleForTesting + boolean handleAuthenticate(SocketAddress socketAddress, ChannelHandlerContext ctx) + { + // the only reason addr would not be instanceof InetSocketAddress is in unit testing, when netty's EmbeddedChannel + // uses EmbeddedSocketAddress. Normally, we'd do an instanceof for that class name, but it's marked with default visibility, + // so we can't reference it outside of it's package (and so it doesn't compile). 
+ if (socketAddress instanceof InetSocketAddress) + { + InetSocketAddress addr = (InetSocketAddress)socketAddress; + if (!authenticator.authenticate(addr.getAddress(), addr.getPort())) + { + if (logger.isTraceEnabled()) + logger.trace("Failed to authenticate peer {}", addr); + ctx.close(); + return false; + } + } + else if (!socketAddress.getClass().getSimpleName().equals("EmbeddedSocketAddress")) + { + ctx.close(); + return false; + } + hasAuthenticated = true; + return true; + } + + /** + * If the connection is using SSL/TLS, log some details about it. + */ + private void logSecureSocketDetails(ChannelHandlerContext ctx) + { + SslHandler sslHandler = ctx.pipeline().get(SslHandler.class); + if (sslHandler != null) + { + SSLSession session = sslHandler.engine().getSession(); + logger.info("connection from peer {}, protocol = {}, cipher suite = {}", + ctx.channel().remoteAddress(), session.getProtocol(), session.getCipherSuite()); + } + } + + /** + * Handles receiving the first message in the internode messaging handshake protocol. If the sender's protocol version + * is accepted, we respond with the second message of the handshake protocol. 
+ */ + @VisibleForTesting + State handleStart(ChannelHandlerContext ctx, ByteBuf in) throws IOException + { + FirstHandshakeMessage msg = FirstHandshakeMessage.maybeDecode(in); + if (msg == null) + return State.START; + + logger.trace("received first handshake message from peer {}, message = {}", ctx.channel().remoteAddress(), msg); + version = msg.messagingVersion; + + if (msg.mode == NettyFactory.Mode.STREAMING) + { + // TODO fill in once streaming is moved to netty + ctx.close(); + return State.AWAIT_STREAM_START_RESPONSE; + } + else + { + if (version < MessagingService.VERSION_30) + { + logger.error("Unable to read obsolete message version {} from {}; The earliest version supported is 3.0.0", version, ctx.channel().remoteAddress()); + ctx.close(); + return State.HANDSHAKE_FAIL; + } + + logger.trace("Connection version {} from {}", version, ctx.channel().remoteAddress()); + compressed = msg.compressionEnabled; + + // if this version is < the MS version the other node is trying + // to connect with, the other node will disconnect + ctx.writeAndFlush(new SecondHandshakeMessage(MessagingService.current_version).encode(ctx.alloc())) + .addListener(ChannelFutureListener.CLOSE_ON_FAILURE); + + // outbound side will reconnect to change the version + if (version > MessagingService.current_version) + { + logger.info("peer wants to use a messaging version higher ({}) than what this node supports ({})", version, MessagingService.current_version); + ctx.close(); + return State.HANDSHAKE_FAIL; + } + + long timeout = TimeUnit.MILLISECONDS.toNanos(DatabaseDescriptor.getRpcTimeout()); + handshakeTimeout = ctx.executor().schedule(() -> failHandshake(ctx), timeout, TimeUnit.MILLISECONDS); + return State.AWAIT_MESSAGING_START_RESPONSE; + } + } + + /** + * Handles the third (and last) message in the internode messaging handshake protocol. Grabs the protocol version and + * IP addr the peer wants to use. 
+ */ + @VisibleForTesting + State handleMessagingStartResponse(ChannelHandlerContext ctx, ByteBuf in) throws IOException + { + ThirdHandshakeMessage msg = ThirdHandshakeMessage.maybeDecode(in); + if (msg == null) + return State.AWAIT_MESSAGING_START_RESPONSE; + + logger.trace("received third handshake message from peer {}, message = {}", ctx.channel().remoteAddress(), msg); + if (handshakeTimeout != null) + { + handshakeTimeout.cancel(false); + handshakeTimeout = null; + } + + int maxVersion = msg.messagingVersion; + if (maxVersion > MessagingService.current_version) + { + logger.error("peer wants to use a messaging version higher ({}) than what this node supports ({})", maxVersion, MessagingService.current_version); + ctx.close(); + return State.HANDSHAKE_FAIL; + } + + // record the (true) version of the endpoint + InetAddress from = msg.address; + MessagingService.instance().setVersion(from, maxVersion); + logger.trace("Set version for {} to {} (will use {})", from, maxVersion, MessagingService.instance().getVersion(from)); + + setupMessagingPipeline(ctx.pipeline(), from, compressed, version); + return State.MESSAGING_HANDSHAKE_COMPLETE; + } + + @VisibleForTesting + void setupMessagingPipeline(ChannelPipeline pipeline, InetAddress peer, boolean compressed, int messagingVersion) + { + if (compressed) + pipeline.addLast(NettyFactory.INBOUND_COMPRESSOR_HANDLER_NAME, NettyFactory.createLz4Decoder(messagingVersion)); + + pipeline.addLast("messageInHandler", new MessageInHandler(peer, messagingVersion)); + pipeline.remove(this); + } + + @VisibleForTesting + void failHandshake(ChannelHandlerContext ctx) + { + // we're not really racing on the handshakeTimeout as we're in the event loop, + // but, hey, defensive programming is beautiful thing! 
+ if (state == State.MESSAGING_HANDSHAKE_COMPLETE || (handshakeTimeout != null && handshakeTimeout.isCancelled())) + return; + + state = State.HANDSHAKE_FAIL; + ctx.close(); + + if (handshakeTimeout != null) + { + handshakeTimeout.cancel(false); + handshakeTimeout = null; + } + } + + @Override + public void channelInactive(ChannelHandlerContext ctx) + { + logger.trace("Failed to properly handshake with peer {}. Closing the channel.", ctx.channel().remoteAddress()); + failHandshake(ctx); + ctx.fireChannelInactive(); + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) + { + logger.error("Failed to properly handshake with peer {}. Closing the channel.", ctx.channel().remoteAddress(), cause); + failHandshake(ctx); + } + + @VisibleForTesting + public State getState() + { + return state; + } + + @VisibleForTesting + public void setState(State nextState) + { + state = nextState; + } + + @VisibleForTesting + void setHandshakeTimeout(Future<?> timeout) + { + handshakeTimeout = timeout; + } +} http://git-wip-us.apache.org/repos/asf/cassandra/blob/356dc3c2/src/java/org/apache/cassandra/net/async/MessageInHandler.java ---------------------------------------------------------------------- diff --git a/src/java/org/apache/cassandra/net/async/MessageInHandler.java b/src/java/org/apache/cassandra/net/async/MessageInHandler.java new file mode 100644 index 0000000..b400512 --- /dev/null +++ b/src/java/org/apache/cassandra/net/async/MessageInHandler.java @@ -0,0 +1,314 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.net.async; + +import java.io.DataInputStream; +import java.io.EOFException; +import java.io.IOException; +import java.net.InetAddress; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.function.BiConsumer; + +import com.google.common.annotations.VisibleForTesting; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelHandlerContext; +import io.netty.handler.codec.ByteToMessageDecoder; +import org.apache.cassandra.db.monitoring.ApproximateTime; +import org.apache.cassandra.exceptions.UnknownTableException; +import org.apache.cassandra.net.CompactEndpointSerializationHelper; +import org.apache.cassandra.net.MessageIn; +import org.apache.cassandra.net.MessagingService; + +/** + * Parses out individual messages from the incoming buffers. Each message, both header and payload, is incrementally built up + * from the available input data, then passed to the {@link #messageConsumer}. + * + * Note: this class derives from {@link ByteToMessageDecoder} to take advantage of the {@link ByteToMessageDecoder.Cumulator} + * behavior across {@link #decode(ChannelHandlerContext, ByteBuf, List)} invocations. That way we don't have to maintain + * the not-fully consumed {@link ByteBuf}s. 
+ */ +class MessageInHandler extends ByteToMessageDecoder +{ + public static final Logger logger = LoggerFactory.getLogger(MessageInHandler.class); + + /** + * The default target for consuming deserialized {@link MessageIn}. + */ + static final BiConsumer<MessageIn, Integer> MESSAGING_SERVICE_CONSUMER = (messageIn, id) -> MessagingService.instance().receive(messageIn, id); + + private enum State + { + READ_FIRST_CHUNK, + READ_IP_ADDRESS, + READ_SECOND_CHUNK, + READ_PARAMETERS_DATA, + READ_PAYLOAD_SIZE, + READ_PAYLOAD + } + + /** + * The byte count for magic, msg id, timestamp values. + */ + @VisibleForTesting + static final int FIRST_SECTION_BYTE_COUNT = 12; + + /** + * The byte count for the verb id and the number of parameters. + */ + private static final int SECOND_SECTION_BYTE_COUNT = 8; + + private final InetAddress peer; + private final int messagingVersion; + + /** + * Abstracts out depending directly on {@link MessagingService#receive(MessageIn, int)}; this makes tests more sane + * as they don't require nor trigger the entire message processing circus. + */ + private final BiConsumer<MessageIn, Integer> messageConsumer; + + private State state; + private MessageHeader messageHeader; + + MessageInHandler(InetAddress peer, int messagingVersion) + { + this (peer, messagingVersion, MESSAGING_SERVICE_CONSUMER); + } + + MessageInHandler(InetAddress peer, int messagingVersion, BiConsumer<MessageIn, Integer> messageConsumer) + { + this.peer = peer; + this.messagingVersion = messagingVersion; + this.messageConsumer = messageConsumer; + state = State.READ_FIRST_CHUNK; + } + + /** + * For each new message coming in, builds up a {@link MessageHeader} instance incrementally. This method + * attempts to deserialize as much header information as it can out of the incoming {@link ByteBuf}, and + * maintains a trivial state machine to remember progress across invocations. 
+ */ + @SuppressWarnings("resource") + public void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) + { + ByteBufDataInputPlus inputPlus = new ByteBufDataInputPlus(in); + try + { + while (true) + { + // an imperfect optimization around calling in.readableBytes() all the time + int readableBytes = in.readableBytes(); + + switch (state) + { + case READ_FIRST_CHUNK: + if (readableBytes < FIRST_SECTION_BYTE_COUNT) + return; + MessagingService.validateMagic(in.readInt()); + messageHeader = new MessageHeader(); + messageHeader.messageId = in.readInt(); + int messageTimestamp = in.readInt(); // make sure to read the sent timestamp, even if DatabaseDescriptor.hasCrossNodeTimeout() is not enabled + messageHeader.constructionTime = MessageIn.deriveConstructionTime(peer, messageTimestamp, ApproximateTime.currentTimeMillis()); + state = State.READ_IP_ADDRESS; + readableBytes -= FIRST_SECTION_BYTE_COUNT; + // fall-through + case READ_IP_ADDRESS: + // unfortunately, this assumes knowledge of how CompactEndpointSerializationHelper serializes data (the first byte is the size). + // first, check that we can actually read the size byte, then check if we can read that number of bytes. + // the "+ 1" is to make sure we have the size byte in addition to the serialized IP addr count of bytes in the buffer. + int serializedAddrSize; + if (readableBytes < 1 || readableBytes < (serializedAddrSize = in.getByte(in.readerIndex()) + 1)) + return; + messageHeader.from = CompactEndpointSerializationHelper.deserialize(inputPlus); + state = State.READ_SECOND_CHUNK; + readableBytes -= serializedAddrSize; + // fall-through + case READ_SECOND_CHUNK: + if (readableBytes < SECOND_SECTION_BYTE_COUNT) + return; + messageHeader.verb = MessagingService.Verb.fromId(in.readInt()); + int paramCount = in.readInt(); + messageHeader.parameterCount = paramCount; + messageHeader.parameters = paramCount == 0 ? 
Collections.emptyMap() : new HashMap<>(); + state = State.READ_PARAMETERS_DATA; + readableBytes -= SECOND_SECTION_BYTE_COUNT; + // fall-through + case READ_PARAMETERS_DATA: + if (messageHeader.parameterCount > 0) + { + if (!readParameters(in, inputPlus, messageHeader.parameterCount, messageHeader.parameters)) + return; + readableBytes = in.readableBytes(); // we read an indeterminate number of bytes for the headers, so just ask the buffer again + } + state = State.READ_PAYLOAD_SIZE; + // fall-through + case READ_PAYLOAD_SIZE: + if (readableBytes < 4) + return; + messageHeader.payloadSize = in.readInt(); + state = State.READ_PAYLOAD; + readableBytes -= 4; + // fall-through + case READ_PAYLOAD: + if (readableBytes < messageHeader.payloadSize) + return; + + // TODO consider deserailizing the messge not on the event loop + MessageIn<Object> messageIn = MessageIn.read(inputPlus, messagingVersion, + messageHeader.messageId, messageHeader.constructionTime, messageHeader.from, + messageHeader.payloadSize, messageHeader.verb, messageHeader.parameters); + + if (messageIn != null) + messageConsumer.accept(messageIn, messageHeader.messageId); + + state = State.READ_FIRST_CHUNK; + messageHeader = null; + break; + default: + throw new IllegalStateException("unknown/unhandled state: " + state); + } + } + } + catch (Exception e) + { + exceptionCaught(ctx, e); + } + } + + /** + * @return <code>true</code> if all the parameters have been read from the {@link ByteBuf}; else, <code>false</code>. 
+ */ + private boolean readParameters(ByteBuf in, ByteBufDataInputPlus inputPlus, int parameterCount, Map<String, byte[]> parameters) throws IOException + { + // makes the assumption that map.size() is a constant time function (HashMap.size() is) + while (parameters.size() < parameterCount) + { + if (!canReadNextParam(in)) + return false; + + String key = DataInputStream.readUTF(inputPlus); + byte[] value = new byte[in.readInt()]; + in.readBytes(value); + parameters.put(key, value); + } + + return true; + } + + /** + * Determine if we can read the next parameter from the {@link ByteBuf}. This method will *always* set the {@code in} + * readIndex back to where it was when this method was invoked. + * + * NOTE: this function would be sooo much simpler if we included a parameters length int in the messaging format, + * instead of checking the remaining readable bytes for each field as we're parsing it. c'est la vie ... + */ + @VisibleForTesting + static boolean canReadNextParam(ByteBuf in) + { + in.markReaderIndex(); + // capture the readableBytes value here to avoid all the virtual function calls. + // subtract 6 as we know we'll be reading a short and an int (for the utf and value lengths). + final int minimumBytesRequired = 6; + int readableBytes = in.readableBytes() - minimumBytesRequired; + if (readableBytes < 0) + return false; + + // this is a tad invasive, but since we know the UTF string is prefaced with a 2-byte length, + // read that to make sure we have enough bytes to read the string itself. 
+ short strLen = in.readShort(); + // check if we can read that many bytes for the UTF + if (strLen > readableBytes) + { + in.resetReaderIndex(); + return false; + } + in.skipBytes(strLen); + readableBytes -= strLen; + + // check if we can read the value length + if (readableBytes < 4) + { + in.resetReaderIndex(); + return false; + } + int valueLength = in.readInt(); + // check if we read that many bytes for the value + if (valueLength > readableBytes) + { + in.resetReaderIndex(); + return false; + } + + in.resetReaderIndex(); + return true; + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) + { + if (cause instanceof EOFException) + logger.trace("eof reading from socket; closing", cause); + else if (cause instanceof UnknownTableException) + logger.warn("Got message from unknown table while reading from socket; closing", cause); + else if (cause instanceof IOException) + logger.trace("IOException reading from socket; closing", cause); + else + logger.warn("Unexpected exception caught in inbound channel pipeline from " + ctx.channel().remoteAddress(), cause); + + ctx.close(); + } + + @Override + public void channelInactive(ChannelHandlerContext ctx) throws Exception + { + logger.debug("received channel closed message for peer {} on local addr {}", ctx.channel().remoteAddress(), ctx.channel().localAddress()); + ctx.fireChannelInactive(); + } + + // should ony be used for testing!!! + @VisibleForTesting + MessageHeader getMessageHeader() + { + return messageHeader; + } + + /** + * A simple struct to hold the message header data as it is being built up. + */ + static class MessageHeader + { + int messageId; + long constructionTime; + InetAddress from; + MessagingService.Verb verb; + int payloadSize; + + Map<String, byte[]> parameters = Collections.emptyMap(); + + /** + * Total number of incoming parameters. 
+ */ + int parameterCount; + } +} http://git-wip-us.apache.org/repos/asf/cassandra/blob/356dc3c2/src/java/org/apache/cassandra/net/async/MessageOutHandler.java ---------------------------------------------------------------------- diff --git a/src/java/org/apache/cassandra/net/async/MessageOutHandler.java b/src/java/org/apache/cassandra/net/async/MessageOutHandler.java new file mode 100644 index 0000000..b4ceb92 --- /dev/null +++ b/src/java/org/apache/cassandra/net/async/MessageOutHandler.java @@ -0,0 +1,324 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.cassandra.net.async; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.UUID; +import java.util.concurrent.TimeUnit; +import java.util.function.Supplier; + +import com.google.common.annotations.VisibleForTesting; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelDuplexHandler; +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelOutboundBuffer; +import io.netty.channel.ChannelPromise; +import io.netty.handler.codec.UnsupportedMessageTypeException; +import io.netty.handler.timeout.IdleState; +import io.netty.handler.timeout.IdleStateEvent; + +import org.apache.cassandra.io.util.DataOutputPlus; +import org.apache.cassandra.net.MessagingService; +import org.apache.cassandra.tracing.TraceState; +import org.apache.cassandra.tracing.Tracing; +import org.apache.cassandra.utils.NanoTimeToCurrentTimeMillis; +import org.apache.cassandra.utils.NoSpamLogger; +import org.apache.cassandra.utils.UUIDGen; + +import static org.apache.cassandra.config.Config.PROPERTY_PREFIX; + +/** + * A Netty {@link ChannelHandler} for serializing outbound messages. + * <p> + * On top of transforming a {@link QueuedMessage} into bytes, this handler also feeds back progress to the linked + * {@link ChannelWriter} so that the latter can take decision on when data should be flushed (with and without coalescing). + * See the javadoc on {@link ChannelWriter} for more details about the callbacks as well as message timeouts. + *<p> + * Note: this class derives from {@link ChannelDuplexHandler} so we can intercept calls to + * {@link #userEventTriggered(ChannelHandlerContext, Object)} and {@link #channelWritabilityChanged(ChannelHandlerContext)}. 
+ */ +class MessageOutHandler extends ChannelDuplexHandler +{ + private static final Logger logger = LoggerFactory.getLogger(MessageOutHandler.class); + private static final NoSpamLogger errorLogger = NoSpamLogger.getLogger(logger, 1, TimeUnit.SECONDS); + + /** + * The default size threshold for deciding when to auto-flush the channel. + */ + private static final int DEFAULT_AUTO_FLUSH_THRESHOLD = 1 << 16; + + // reatining the pre 4.0 property name for backward compatibility. + private static final String AUTO_FLUSH_PROPERTY = PROPERTY_PREFIX + "otc_buffer_size"; + static final int AUTO_FLUSH_THRESHOLD = Integer.getInteger(AUTO_FLUSH_PROPERTY, DEFAULT_AUTO_FLUSH_THRESHOLD); + + /** + * The amount of prefix data, in bytes, before the serialized message. + */ + private static final int MESSAGE_PREFIX_SIZE = 12; + + private final OutboundConnectionIdentifier connectionId; + + /** + * The version of the messaging protocol we're communicating at. + */ + private final int targetMessagingVersion; + + /** + * The minumum size at which we'll automatically flush the channel. 
+ */ + private final int flushSizeThreshold; + + private final ChannelWriter channelWriter; + + private final Supplier<QueuedMessage> backlogSupplier; + + MessageOutHandler(OutboundConnectionIdentifier connectionId, int targetMessagingVersion, ChannelWriter channelWriter, Supplier<QueuedMessage> backlogSupplier) + { + this (connectionId, targetMessagingVersion, channelWriter, backlogSupplier, AUTO_FLUSH_THRESHOLD); + } + + MessageOutHandler(OutboundConnectionIdentifier connectionId, int targetMessagingVersion, ChannelWriter channelWriter, Supplier<QueuedMessage> backlogSupplier, int flushThreshold) + { + this.connectionId = connectionId; + this.targetMessagingVersion = targetMessagingVersion; + this.channelWriter = channelWriter; + this.flushSizeThreshold = flushThreshold; + this.backlogSupplier = backlogSupplier; + } + + @Override + public void write(ChannelHandlerContext ctx, Object o, ChannelPromise promise) + { + // this is a temporary fix until https://github.com/netty/netty/pull/6867 is released (probably netty 4.1.13). + // TL;DR a closed channel can still process messages in the pipeline that were queued before the close. + // the channel handlers are removed from the channel potentially saync from the close operation. + if (!ctx.channel().isOpen()) + { + logger.debug("attempting to process a message in the pipeline, but the channel is closed", ctx.channel().id()); + return; + } + + ByteBuf out = null; + try + { + if (!isMessageValid(o, promise)) + return; + + QueuedMessage msg = (QueuedMessage) o; + + // frame size includes the magic and and other values *before* the actual serialized message. 
+ // note: don't even bother to check the compressed size (if compression is enabled for the channel), + // cuz if it's this large already, we're probably screwed anyway + long currentFrameSize = MESSAGE_PREFIX_SIZE + msg.message.serializedSize(targetMessagingVersion); + if (currentFrameSize > Integer.MAX_VALUE || currentFrameSize < 0) + { + promise.tryFailure(new IllegalStateException(String.format("%s illegal frame size: %d, ignoring message", connectionId, currentFrameSize))); + return; + } + + out = ctx.alloc().ioBuffer((int)currentFrameSize); + + captureTracingInfo(msg); + serializeMessage(msg, out); + ctx.write(out, promise); + + // check to see if we should flush based on buffered size + ChannelOutboundBuffer outboundBuffer = ctx.channel().unsafe().outboundBuffer(); + if (outboundBuffer != null && outboundBuffer.totalPendingWriteBytes() >= flushSizeThreshold) + ctx.flush(); + } + catch(Exception e) + { + if (out != null && out.refCnt() > 0) + out.release(out.refCnt()); + exceptionCaught(ctx, e); + promise.tryFailure(e); + } + finally + { + // Make sure we signal the outChanel even in case of errors. + channelWriter.onMessageProcessed(ctx); + } + } + + /** + * Test to see if the message passed in is a {@link QueuedMessage} and if it has timed out or not. If the checks fail, + * this method has the side effect of modifying the {@link ChannelPromise}. + */ + boolean isMessageValid(Object o, ChannelPromise promise) + { + // optimize for the common case + if (o instanceof QueuedMessage) + { + if (!((QueuedMessage)o).isTimedOut()) + { + return true; + } + else + { + promise.tryFailure(ExpiredException.INSTANCE); + } + } + else + { + promise.tryFailure(new UnsupportedMessageTypeException(connectionId + + " msg must be an instance of " + QueuedMessage.class.getSimpleName())); + } + return false; + } + + /** + * Record any tracing data, if enabled on this message. 
+ */ + @VisibleForTesting + void captureTracingInfo(QueuedMessage msg) + { + try + { + byte[] sessionBytes = msg.message.parameters.get(Tracing.TRACE_HEADER); + if (sessionBytes != null) + { + UUID sessionId = UUIDGen.getUUID(ByteBuffer.wrap(sessionBytes)); + TraceState state = Tracing.instance.get(sessionId); + String message = String.format("Sending %s message to %s, size = %d bytes", + msg.message.verb, connectionId.connectionAddress(), + msg.message.serializedSize(targetMessagingVersion) + MESSAGE_PREFIX_SIZE); + // session may have already finished; see CASSANDRA-5668 + if (state == null) + { + byte[] traceTypeBytes = msg.message.parameters.get(Tracing.TRACE_TYPE); + Tracing.TraceType traceType = traceTypeBytes == null ? Tracing.TraceType.QUERY : Tracing.TraceType.deserialize(traceTypeBytes[0]); + Tracing.instance.trace(ByteBuffer.wrap(sessionBytes), message, traceType.getTTL()); + } + else + { + state.trace(message); + if (msg.message.verb == MessagingService.Verb.REQUEST_RESPONSE) + Tracing.instance.doneWithNonLocalSession(state); + } + } + } + catch (Exception e) + { + logger.warn("{} failed to capture the tracing info for an outbound message, ignoring", connectionId, e); + } + } + + private void serializeMessage(QueuedMessage msg, ByteBuf out) throws IOException + { + out.writeInt(MessagingService.PROTOCOL_MAGIC); + out.writeInt(msg.id); + + // int cast cuts off the high-order half of the timestamp, which we can assume remains + // the same between now and when the recipient reconstructs it. + out.writeInt((int) NanoTimeToCurrentTimeMillis.convert(msg.timestampNanos)); + @SuppressWarnings("resource") + DataOutputPlus outStream = new ByteBufDataOutputPlus(out); + msg.message.serialize(outStream, targetMessagingVersion); + + // next few lines are for debugging ... massively helpful!! + // if we allocated too much buffer for this message, we'll log here. 
+ // if we allocated to little buffer space, we would have hit an exception when trying to write more bytes to it + if (out.isWritable()) + errorLogger.error("{} reported message size {}, actual message size {}, msg {}", + connectionId, out.capacity(), out.writerIndex(), msg.message); + } + + @Override + public void flush(ChannelHandlerContext ctx) + { + channelWriter.onTriggeredFlush(ctx); + } + + + /** + * {@inheritDoc} + * + * When the channel becomes writable (assuming it was previously unwritable), try to eat through any backlogged messages + * {@link #backlogSupplier}. As we're on the event loop when this is invoked, no one else can fill up the netty + * {@link ChannelOutboundBuffer}, so we should be able to make decent progress chewing through the backlog + * (assuming not large messages). Any messages messages written from {@link OutboundMessagingConnection} threads won't + * be processed immediately; they'll be queued up as tasks, and once this function return, those messages can begin + * to be consumed. + * <p> + * Note: this is invoked on the netty event loop. + */ + @Override + public void channelWritabilityChanged(ChannelHandlerContext ctx) + { + if (!ctx.channel().isWritable()) + return; + + // guarantee at least a minimal amount of progress (one messge from the backlog) by using a do-while loop. + do + { + QueuedMessage msg = backlogSupplier.get(); + if (msg == null || !channelWriter.write(msg, false)) + break; + } while (ctx.channel().isWritable()); + } + + /** + * {@inheritDoc} + * + * If we get an {@link IdleStateEvent} for the write path, we want to close the channel as we can't make progress. + * That assumes, of course, that there's any outstanding bytes in the channel to write. We don't necesarrily care + * about idleness (for example, gossip channels will be idle most of the time), but instead our concern is + * the ability to make progress when there's work to be done. + * <p> + * Note: this is invoked on the netty event loop. 
+ */ + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) + { + if (evt instanceof IdleStateEvent && ((IdleStateEvent)evt).state() == IdleState.WRITER_IDLE) + { + ChannelOutboundBuffer cob = ctx.channel().unsafe().outboundBuffer(); + if (cob != null && cob.totalPendingWriteBytes() > 0) + { + ctx.channel().attr(ChannelWriter.PURGE_MESSAGES_CHANNEL_ATTR) + .compareAndSet(Boolean.FALSE, Boolean.TRUE); + ctx.close(); + } + } + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) + { + if (cause instanceof IOException) + logger.trace("{} io error", connectionId, cause); + else + logger.warn("{} error", connectionId, cause); + + ctx.close(); + } + + @Override + public void close(ChannelHandlerContext ctx, ChannelPromise promise) + { + ctx.flush(); + ctx.close(promise); + } +} http://git-wip-us.apache.org/repos/asf/cassandra/blob/356dc3c2/src/java/org/apache/cassandra/net/async/MessageResult.java ---------------------------------------------------------------------- diff --git a/src/java/org/apache/cassandra/net/async/MessageResult.java b/src/java/org/apache/cassandra/net/async/MessageResult.java new file mode 100644 index 0000000..b0dc4dc --- /dev/null +++ b/src/java/org/apache/cassandra/net/async/MessageResult.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.net.async; + +import io.netty.util.concurrent.Future; + +/** + * A simple, reusable struct that holds the unprocessed result of sending a message via netty. This object is intended + * to be reusable to avoid creating a bunch of garbage (just for processing the results of sending a message). + * + * The intended use is to be a member field in a class, like {@link ChannelWriter}, repopulated on each message result, + * and then immediately cleared (via {@link #clearAll()}) when done. + */ +public class MessageResult +{ + ChannelWriter writer; + QueuedMessage msg; + Future<? super Void> future; + boolean allowReconnect; + + void setAll(ChannelWriter writer, QueuedMessage msg, Future<? super Void> future, boolean allowReconnect) + { + this.writer = writer; + this.msg = msg; + this.future = future; + this.allowReconnect = allowReconnect; + } + + void clearAll() + { + this.writer = null; + this.msg = null; + this.future = null; + } +} --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@cassandra.apache.org For additional commands, e-mail: commits-h...@cassandra.apache.org