http://git-wip-us.apache.org/repos/asf/cassandra/blob/356dc3c2/test/unit/org/apache/cassandra/net/async/ChannelWriterTest.java ---------------------------------------------------------------------- diff --git a/test/unit/org/apache/cassandra/net/async/ChannelWriterTest.java b/test/unit/org/apache/cassandra/net/async/ChannelWriterTest.java new file mode 100644 index 0000000..128fe4b --- /dev/null +++ b/test/unit/org/apache/cassandra/net/async/ChannelWriterTest.java @@ -0,0 +1,312 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.cassandra.net.async; + +import java.io.IOException; +import java.net.InetSocketAddress; +import java.util.Optional; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.LinkedBlockingQueue; + +import org.junit.Assert; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelOption; +import io.netty.channel.ChannelOutboundHandlerAdapter; +import io.netty.channel.ChannelPromise; +import io.netty.channel.WriteBufferWaterMark; +import io.netty.channel.embedded.EmbeddedChannel; +import org.apache.cassandra.config.DatabaseDescriptor; +import org.apache.cassandra.net.MessageOut; +import org.apache.cassandra.net.MessagingService; +import org.apache.cassandra.net.async.ChannelWriter.CoalescingChannelWriter; +import org.apache.cassandra.utils.CoalescingStrategies; +import org.apache.cassandra.utils.CoalescingStrategies.CoalescingStrategy; + +import static org.apache.cassandra.net.MessagingService.Verb.ECHO; + +/** + * with the write_Coalescing_* methods, if there's data in the channel.unsafe().outboundBuffer() + * it means that there's something in the channel that hasn't yet been flushed to the transport (socket). + * once a flush occurs, there will be an entry in EmbeddedChannel's outboundQueue. those two facts are leveraged in these tests. 
+ */ +public class ChannelWriterTest +{ + private static final int COALESCE_WINDOW_MS = 10; + + private EmbeddedChannel channel; + private ChannelWriter channelWriter; + private NonSendingOutboundMessagingConnection omc; + private Optional<CoalescingStrategy> coalescingStrategy; + + @BeforeClass + public static void before() + { + DatabaseDescriptor.daemonInitialization(); + } + + @Before + public void setup() + { + OutboundConnectionIdentifier id = OutboundConnectionIdentifier.small(new InetSocketAddress("127.0.0.1", 0), + new InetSocketAddress("127.0.0.2", 0)); + channel = new EmbeddedChannel(); + omc = new NonSendingOutboundMessagingConnection(id, null, Optional.empty()); + channelWriter = ChannelWriter.create(channel, omc::handleMessageResult, Optional.empty()); + channel.pipeline().addFirst(new MessageOutHandler(id, MessagingService.current_version, channelWriter, () -> null)); + coalescingStrategy = CoalescingStrategies.newCoalescingStrategy(CoalescingStrategies.Strategy.FIXED.name(), COALESCE_WINDOW_MS, null, "test"); + } + + @Test + public void create_nonCoalescing() + { + Assert.assertSame(ChannelWriter.SimpleChannelWriter.class, ChannelWriter.create(channel, omc::handleMessageResult, Optional.empty()).getClass()); + } + + @Test + public void create_Coalescing() + { + Assert.assertSame(CoalescingChannelWriter.class, ChannelWriter.create(channel, omc::handleMessageResult, coalescingStrategy).getClass()); + } + + @Test + public void write_IsWritable() + { + Assert.assertTrue(channel.isWritable()); + Assert.assertTrue(channelWriter.write(new QueuedMessage(new MessageOut<>(ECHO), 42), true)); + Assert.assertTrue(channel.isWritable()); + Assert.assertTrue(channel.releaseOutbound()); + } + + @Test + public void write_NotWritable() + { + channel.config().setOption(ChannelOption.WRITE_BUFFER_WATER_MARK, new WriteBufferWaterMark(1, 2)); + + // send one message through, which will trigger the writability check (and turn it off) + 
Assert.assertTrue(channel.isWritable()); + ByteBuf buf = channel.alloc().buffer(8, 8); + channel.unsafe().outboundBuffer().addMessage(buf, buf.capacity(), channel.newPromise()); + Assert.assertFalse(channel.isWritable()); + Assert.assertFalse(channelWriter.write(new QueuedMessage(new MessageOut<>(ECHO), 42), true)); + Assert.assertFalse(channel.isWritable()); + Assert.assertFalse(channel.releaseOutbound()); + buf.release(); + } + + @Test + public void write_NotWritableButWriteAnyway() + { + channel.config().setOption(ChannelOption.WRITE_BUFFER_WATER_MARK, new WriteBufferWaterMark(1, 2)); + + // send one message through, which will trigger the writability check (and turn it off) + Assert.assertTrue(channel.isWritable()); + ByteBuf buf = channel.alloc().buffer(8, 8); + channel.unsafe().outboundBuffer().addMessage(buf, buf.capacity(), channel.newPromise()); + Assert.assertFalse(channel.isWritable()); + Assert.assertTrue(channelWriter.write(new QueuedMessage(new MessageOut<>(ECHO), 42), false)); + Assert.assertTrue(channel.isWritable()); + Assert.assertTrue(channel.releaseOutbound()); + } + + @Test + public void write_Coalescing_LostRaceForFlushTask() + { + CoalescingChannelWriter channelWriter = resetEnvForCoalescing(DatabaseDescriptor.getOtcCoalescingEnoughCoalescedMessages()); + channelWriter.scheduledFlush.set(true); + Assert.assertTrue(channel.unsafe().outboundBuffer().totalPendingWriteBytes() == 0); + Assert.assertTrue(channelWriter.write(new QueuedMessage(new MessageOut<>(ECHO), 42), true)); + Assert.assertTrue(channel.unsafe().outboundBuffer().totalPendingWriteBytes() > 0); + Assert.assertFalse(channel.releaseOutbound()); + Assert.assertTrue(channelWriter.scheduledFlush.get()); + } + + @Test + public void write_Coalescing_HitMinMessageCountForImmediateCoalesce() + { + CoalescingChannelWriter channelWriter = resetEnvForCoalescing(1); + + Assert.assertTrue(channel.unsafe().outboundBuffer().totalPendingWriteBytes() == 0); + 
Assert.assertFalse(channelWriter.scheduledFlush.get()); + Assert.assertTrue(channelWriter.write(new QueuedMessage(new MessageOut<>(ECHO), 42), true)); + + Assert.assertTrue(channel.unsafe().outboundBuffer().totalPendingWriteBytes() == 0); + Assert.assertTrue(channel.releaseOutbound()); + Assert.assertFalse(channelWriter.scheduledFlush.get()); + } + + @Test + public void write_Coalescing_ScheduleFlushTask() + { + CoalescingChannelWriter channelWriter = resetEnvForCoalescing(DatabaseDescriptor.getOtcCoalescingEnoughCoalescedMessages()); + + Assert.assertTrue(channel.unsafe().outboundBuffer().totalPendingWriteBytes() == 0); + Assert.assertFalse(channelWriter.scheduledFlush.get()); + Assert.assertTrue(channelWriter.write(new QueuedMessage(new MessageOut<>(ECHO), 42), true)); + + Assert.assertTrue(channelWriter.scheduledFlush.get()); + Assert.assertTrue(channel.unsafe().outboundBuffer().totalPendingWriteBytes() > 0); + Assert.assertTrue(channelWriter.scheduledFlush.get()); + + // this unfortunately know a little too much about how the sausage is made in CoalescingChannelWriter :-/ + channel.runScheduledPendingTasks(); + channel.runPendingTasks(); + Assert.assertTrue(channel.unsafe().outboundBuffer().totalPendingWriteBytes() == 0); + Assert.assertFalse(channelWriter.scheduledFlush.get()); + Assert.assertTrue(channel.releaseOutbound()); + } + + private CoalescingChannelWriter resetEnvForCoalescing(int minMessagesForCoalesce) + { + channel = new EmbeddedChannel(); + CoalescingChannelWriter cw = new CoalescingChannelWriter(channel, omc::handleMessageResult, coalescingStrategy.get(), minMessagesForCoalesce); + channel.pipeline().addFirst(new ChannelOutboundHandlerAdapter() + { + public void flush(ChannelHandlerContext ctx) throws Exception + { + cw.onTriggeredFlush(ctx); + } + }); + omc.setChannelWriter(cw); + return cw; + } + + @Test + public void writeBacklog_Empty() + { + BlockingQueue<QueuedMessage> queue = new LinkedBlockingQueue<>(); + Assert.assertEquals(0, 
channelWriter.writeBacklog(queue, false)); + Assert.assertFalse(channel.releaseOutbound()); + } + + @Test + public void writeBacklog_ChannelNotWritable() + { + Assert.assertTrue(channel.isWritable()); + // force the channel to be non writable + channel.config().setOption(ChannelOption.WRITE_BUFFER_WATER_MARK, new WriteBufferWaterMark(1, 2)); + ByteBuf buf = channel.alloc().buffer(8, 8); + channel.unsafe().outboundBuffer().addMessage(buf, buf.capacity(), channel.newPromise()); + Assert.assertFalse(channel.isWritable()); + + Assert.assertEquals(0, channelWriter.writeBacklog(new LinkedBlockingQueue<>(), false)); + Assert.assertFalse(channel.releaseOutbound()); + Assert.assertFalse(channel.isWritable()); + buf.release(); + } + + @Test + public void writeBacklog_NotEmpty() + { + BlockingQueue<QueuedMessage> queue = new LinkedBlockingQueue<>(); + int count = 12; + for (int i = 0; i < count; i++) + queue.offer(new QueuedMessage(new MessageOut<>(ECHO), i)); + Assert.assertEquals(count, channelWriter.writeBacklog(queue, false)); + Assert.assertTrue(channel.releaseOutbound()); + } + + @Test + public void close() + { + Assert.assertFalse(channelWriter.isClosed()); + Assert.assertTrue(channel.isOpen()); + channelWriter.close(); + Assert.assertFalse(channel.isOpen()); + Assert.assertTrue(channelWriter.isClosed()); + } + + @Test + public void softClose() + { + Assert.assertFalse(channelWriter.isClosed()); + Assert.assertTrue(channel.isOpen()); + channelWriter.softClose(); + Assert.assertFalse(channel.isOpen()); + Assert.assertTrue(channelWriter.isClosed()); + } + + @Test + public void handleMessagePromise_FutureIsCancelled() + { + ChannelPromise promise = channel.newPromise(); + promise.cancel(false); + channelWriter.handleMessageFuture(promise, new QueuedMessage(new MessageOut<>(ECHO), 1), true); + Assert.assertTrue(channel.isActive()); + Assert.assertEquals(1, omc.getCompletedMessages().longValue()); + Assert.assertEquals(0, omc.getDroppedMessages().longValue()); + } + + @Test 
+ public void handleMessagePromise_ExpiredException_DoNotRetryMsg() + { + ChannelPromise promise = channel.newPromise(); + promise.setFailure(new ExpiredException()); + + channelWriter.handleMessageFuture(promise, new QueuedMessage(new MessageOut<>(ECHO), 1), true); + Assert.assertTrue(channel.isActive()); + Assert.assertEquals(1, omc.getCompletedMessages().longValue()); + Assert.assertEquals(1, omc.getDroppedMessages().longValue()); + Assert.assertFalse(omc.sendMessageInvoked); + } + + @Test + public void handleMessagePromise_NonIOException() + { + ChannelPromise promise = channel.newPromise(); + promise.setFailure(new NullPointerException("this is a test")); + channelWriter.handleMessageFuture(promise, new QueuedMessage(new MessageOut<>(ECHO), 1), true); + Assert.assertTrue(channel.isActive()); + Assert.assertEquals(1, omc.getCompletedMessages().longValue()); + Assert.assertEquals(0, omc.getDroppedMessages().longValue()); + Assert.assertFalse(omc.sendMessageInvoked); + } + + @Test + public void handleMessagePromise_IOException_ChannelNotClosed_RetryMsg() + { + ChannelPromise promise = channel.newPromise(); + promise.setFailure(new IOException("this is a test")); + Assert.assertTrue(channel.isActive()); + channelWriter.handleMessageFuture(promise, new QueuedMessage(new MessageOut<>(ECHO), 1, 0, true, true), true); + + Assert.assertFalse(channel.isActive()); + Assert.assertEquals(1, omc.getCompletedMessages().longValue()); + Assert.assertEquals(0, omc.getDroppedMessages().longValue()); + Assert.assertTrue(omc.sendMessageInvoked); + } + + @Test + public void handleMessagePromise_Cancelled() + { + ChannelPromise promise = channel.newPromise(); + promise.cancel(false); + Assert.assertTrue(channel.isActive()); + channelWriter.handleMessageFuture(promise, new QueuedMessage(new MessageOut<>(ECHO), 1, 0, true, true), true); + + Assert.assertTrue(channel.isActive()); + Assert.assertEquals(1, omc.getCompletedMessages().longValue()); + Assert.assertEquals(0, 
omc.getDroppedMessages().longValue()); + Assert.assertFalse(omc.sendMessageInvoked); + } +}
http://git-wip-us.apache.org/repos/asf/cassandra/blob/356dc3c2/test/unit/org/apache/cassandra/net/async/HandshakeHandlersTest.java ---------------------------------------------------------------------- diff --git a/test/unit/org/apache/cassandra/net/async/HandshakeHandlersTest.java b/test/unit/org/apache/cassandra/net/async/HandshakeHandlersTest.java new file mode 100644 index 0000000..fa6e2b5 --- /dev/null +++ b/test/unit/org/apache/cassandra/net/async/HandshakeHandlersTest.java @@ -0,0 +1,204 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.cassandra.net.async; + +import java.io.IOException; +import java.net.InetSocketAddress; +import java.nio.ByteBuffer; +import java.util.Optional; + +import org.junit.Assert; +import org.junit.BeforeClass; +import org.junit.Test; + +import io.netty.channel.embedded.EmbeddedChannel; +import org.apache.cassandra.SchemaLoader; +import org.apache.cassandra.auth.AllowAllInternodeAuthenticator; +import org.apache.cassandra.db.ColumnFamilyStore; +import org.apache.cassandra.db.Keyspace; +import org.apache.cassandra.db.Mutation; +import org.apache.cassandra.db.RowUpdateBuilder; +import org.apache.cassandra.db.compaction.CompactionManager; +import org.apache.cassandra.db.marshal.AsciiType; +import org.apache.cassandra.db.marshal.BytesType; +import org.apache.cassandra.exceptions.ConfigurationException; +import org.apache.cassandra.net.MessageOut; +import org.apache.cassandra.net.MessagingService; +import org.apache.cassandra.schema.KeyspaceParams; + +import static org.apache.cassandra.net.async.InboundHandshakeHandler.State.MESSAGING_HANDSHAKE_COMPLETE; +import static org.apache.cassandra.net.async.OutboundMessagingConnection.State.READY; + +public class HandshakeHandlersTest +{ + private static final String KEYSPACE1 = "NettyPipilineTest"; + private static final String STANDARD1 = "Standard1"; + + private static final InetSocketAddress LOCAL_ADDR = new InetSocketAddress("127.0.0.1", 9999); + private static final InetSocketAddress REMOTE_ADDR = new InetSocketAddress("127.0.0.2", 9999); + private static final int MESSAGING_VERSION = MessagingService.current_version; + private static final OutboundConnectionIdentifier connectionId = OutboundConnectionIdentifier.small(LOCAL_ADDR, REMOTE_ADDR); + + @BeforeClass + public static void beforeClass() throws ConfigurationException + { + SchemaLoader.prepareServer(); + SchemaLoader.createKeyspace(KEYSPACE1, + KeyspaceParams.simple(1), + SchemaLoader.standardCFMD(KEYSPACE1, STANDARD1, 0, AsciiType.instance, 
BytesType.instance)); + CompactionManager.instance.disableAutoCompaction(); + } + + @Test + public void handshake_HappyPath() + { + // beacuse both CHH & SHH are ChannelInboundHandlers, we can't use the same EmbeddedChannel to handle them + InboundHandshakeHandler inboundHandshakeHandler = new InboundHandshakeHandler(new TestAuthenticator(true)); + EmbeddedChannel inboundChannel = new EmbeddedChannel(inboundHandshakeHandler); + + OutboundMessagingConnection imc = new OutboundMessagingConnection(connectionId, null, Optional.empty(), new AllowAllInternodeAuthenticator()); + OutboundConnectionParams params = OutboundConnectionParams.builder() + .connectionId(connectionId) + .callback(imc::finishHandshake) + .mode(NettyFactory.Mode.MESSAGING) + .protocolVersion(MessagingService.current_version) + .coalescingStrategy(Optional.empty()) + .build(); + OutboundHandshakeHandler outboundHandshakeHandler = new OutboundHandshakeHandler(params); + EmbeddedChannel outboundChannel = new EmbeddedChannel(outboundHandshakeHandler); + Assert.assertEquals(1, outboundChannel.outboundMessages().size()); + + // move internode protocol Msg1 to the server's channel + Object o; + while ((o = outboundChannel.readOutbound()) != null) + inboundChannel.writeInbound(o); + Assert.assertEquals(1, inboundChannel.outboundMessages().size()); + + // move internode protocol Msg2 to the client's channel + while ((o = inboundChannel.readOutbound()) != null) + outboundChannel.writeInbound(o); + Assert.assertEquals(1, outboundChannel.outboundMessages().size()); + + // move internode protocol Msg3 to the server's channel + while ((o = outboundChannel.readOutbound()) != null) + inboundChannel.writeInbound(o); + + Assert.assertEquals(READY, imc.getState()); + Assert.assertEquals(MESSAGING_HANDSHAKE_COMPLETE, inboundHandshakeHandler.getState()); + } + + @Test + public void lotsOfMutations_NoCompression() throws IOException + { + lotsOfMutations(false); + } + + @Test + public void 
lotsOfMutations_WithCompression() throws IOException + { + lotsOfMutations(true); + } + + private void lotsOfMutations(boolean compress) + { + TestChannels channels = buildChannels(compress); + EmbeddedChannel outboundChannel = channels.outboundChannel; + EmbeddedChannel inboundChannel = channels.inboundChannel; + + // now the actual test! + ByteBuffer buf = ByteBuffer.allocate(1 << 10); + byte[] bytes = "ThisIsA16CharStr".getBytes(); + while (buf.remaining() > 0) + buf.put(bytes); + + // write a bunch of messages to the channel + ColumnFamilyStore cfs1 = Keyspace.open(KEYSPACE1).getColumnFamilyStore(STANDARD1); + int count = 1024; + for (int i = 0; i < count; i++) + { + if (i % 2 == 0) + { + Mutation mutation = new RowUpdateBuilder(cfs1.metadata.get(), 0, "k") + .clustering("bytes") + .add("val", buf) + .build(); + + QueuedMessage msg = new QueuedMessage(mutation.createMessage(), i); + outboundChannel.writeAndFlush(msg); + } + else + { + outboundChannel.writeAndFlush(new QueuedMessage(new MessageOut<>(MessagingService.Verb.ECHO), i)); + } + } + outboundChannel.flush(); + + // move the messages to the other channel + Object o; + while ((o = outboundChannel.readOutbound()) != null) + inboundChannel.writeInbound(o); + + Assert.assertTrue(outboundChannel.outboundMessages().isEmpty()); + Assert.assertFalse(inboundChannel.finishAndReleaseAll()); + } + + private TestChannels buildChannels(boolean compress) + { + OutboundConnectionParams params = OutboundConnectionParams.builder() + .connectionId(connectionId) + .callback(this::nop) + .mode(NettyFactory.Mode.MESSAGING) + .compress(compress) + .coalescingStrategy(Optional.empty()) + .protocolVersion(MessagingService.current_version) + .build(); + OutboundHandshakeHandler outboundHandshakeHandler = new OutboundHandshakeHandler(params); + EmbeddedChannel outboundChannel = new EmbeddedChannel(outboundHandshakeHandler); + OutboundMessagingConnection omc = new OutboundMessagingConnection(connectionId, null, Optional.empty(), 
new AllowAllInternodeAuthenticator()); + omc.setTargetVersion(MESSAGING_VERSION); + outboundHandshakeHandler.setupPipeline(outboundChannel, MESSAGING_VERSION); + + // remove the outbound handshake message from the outbound messages + outboundChannel.outboundMessages().clear(); + + InboundHandshakeHandler handler = new InboundHandshakeHandler(new TestAuthenticator(true)); + EmbeddedChannel inboundChannel = new EmbeddedChannel(handler); + handler.setupMessagingPipeline(inboundChannel.pipeline(), REMOTE_ADDR.getAddress(), compress, MESSAGING_VERSION); + + return new TestChannels(outboundChannel, inboundChannel); + } + + private static class TestChannels + { + final EmbeddedChannel outboundChannel; + final EmbeddedChannel inboundChannel; + + TestChannels(EmbeddedChannel outboundChannel, EmbeddedChannel inboundChannel) + { + this.outboundChannel = outboundChannel; + this.inboundChannel = inboundChannel; + } + } + + private Void nop(OutboundHandshakeHandler.HandshakeResult handshakeResult) + { + // do nothing, really + return null; + } +} http://git-wip-us.apache.org/repos/asf/cassandra/blob/356dc3c2/test/unit/org/apache/cassandra/net/async/HandshakeProtocolTest.java ---------------------------------------------------------------------- diff --git a/test/unit/org/apache/cassandra/net/async/HandshakeProtocolTest.java b/test/unit/org/apache/cassandra/net/async/HandshakeProtocolTest.java new file mode 100644 index 0000000..a3d646d --- /dev/null +++ b/test/unit/org/apache/cassandra/net/async/HandshakeProtocolTest.java @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.cassandra.net.async; + +import org.junit.After; +import org.junit.Assert; +import org.junit.BeforeClass; +import org.junit.Test; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.PooledByteBufAllocator; +import org.apache.cassandra.config.DatabaseDescriptor; +import org.apache.cassandra.net.MessagingService; +import org.apache.cassandra.net.async.HandshakeProtocol.FirstHandshakeMessage; +import org.apache.cassandra.net.async.HandshakeProtocol.SecondHandshakeMessage; +import org.apache.cassandra.net.async.HandshakeProtocol.ThirdHandshakeMessage; +import org.apache.cassandra.utils.FBUtilities; + +import static org.junit.Assert.assertEquals; + +public class HandshakeProtocolTest +{ + private ByteBuf buf; + + @BeforeClass + public static void before() + { + // Kind of stupid, but the test trigger the initialization of the MessagingService class and that require + // DatabaseDescriptor to be configured ... 
+ DatabaseDescriptor.daemonInitialization(); + } + + @After + public void tearDown() + { + if (buf != null && buf.refCnt() > 0) + buf.release(); + } + + @Test + public void firstMessageTest() throws Exception + { + firstMessageTest(NettyFactory.Mode.MESSAGING, false); + firstMessageTest(NettyFactory.Mode.MESSAGING, true); + firstMessageTest(NettyFactory.Mode.STREAMING, false); + firstMessageTest(NettyFactory.Mode.STREAMING, true); + } + + private void firstMessageTest(NettyFactory.Mode mode, boolean compression) throws Exception + { + FirstHandshakeMessage before = new FirstHandshakeMessage(MessagingService.current_version, mode, compression); + buf = before.encode(PooledByteBufAllocator.DEFAULT); + FirstHandshakeMessage after = FirstHandshakeMessage.maybeDecode(buf); + assertEquals(before, after); + assertEquals(before.hashCode(), after.hashCode()); + Assert.assertFalse(before.equals(null)); + } + + @Test + public void secondMessageTest() throws Exception + { + SecondHandshakeMessage before = new SecondHandshakeMessage(MessagingService.current_version); + buf = before.encode(PooledByteBufAllocator.DEFAULT); + SecondHandshakeMessage after = SecondHandshakeMessage.maybeDecode(buf); + assertEquals(before, after); + assertEquals(before.hashCode(), after.hashCode()); + Assert.assertFalse(before.equals(null)); + } + + @Test + public void thirdMessageTest() throws Exception + { + ThirdHandshakeMessage before = new ThirdHandshakeMessage(MessagingService.current_version, FBUtilities.getBroadcastAddress()); + buf = before.encode(PooledByteBufAllocator.DEFAULT); + ThirdHandshakeMessage after = ThirdHandshakeMessage.maybeDecode(buf); + assertEquals(before, after); + assertEquals(before.hashCode(), after.hashCode()); + Assert.assertFalse(before.equals(null)); + } +} \ No newline at end of file http://git-wip-us.apache.org/repos/asf/cassandra/blob/356dc3c2/test/unit/org/apache/cassandra/net/async/InboundHandshakeHandlerTest.java 
---------------------------------------------------------------------- diff --git a/test/unit/org/apache/cassandra/net/async/InboundHandshakeHandlerTest.java b/test/unit/org/apache/cassandra/net/async/InboundHandshakeHandlerTest.java new file mode 100644 index 0000000..44dc469 --- /dev/null +++ b/test/unit/org/apache/cassandra/net/async/InboundHandshakeHandlerTest.java @@ -0,0 +1,289 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.cassandra.net.async; + +import java.io.IOException; +import java.net.InetSocketAddress; +import java.net.SocketAddress; +import java.util.ArrayList; + +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufOutputStream; +import io.netty.buffer.PooledByteBufAllocator; +import io.netty.buffer.Unpooled; +import io.netty.channel.ChannelPipeline; +import io.netty.channel.ChannelPromise; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.handler.codec.compression.Lz4FrameDecoder; +import io.netty.handler.codec.compression.Lz4FrameEncoder; +import org.apache.cassandra.config.DatabaseDescriptor; +import org.apache.cassandra.net.CompactEndpointSerializationHelper; +import org.apache.cassandra.net.MessagingService; +import org.apache.cassandra.net.async.HandshakeProtocol.FirstHandshakeMessage; +import org.apache.cassandra.net.async.HandshakeProtocol.ThirdHandshakeMessage; +import org.apache.cassandra.net.async.InboundHandshakeHandler.State; + +import static org.apache.cassandra.net.async.NettyFactory.Mode.MESSAGING; + +public class InboundHandshakeHandlerTest +{ + private static final InetSocketAddress addr = new InetSocketAddress("127.0.0.1", 0); + private static final int MESSAGING_VERSION = MessagingService.current_version; + private static final int VERSION_30 = MessagingService.VERSION_30; + + private InboundHandshakeHandler handler; + private EmbeddedChannel channel; + private ByteBuf buf; + + @BeforeClass + public static void beforeClass() + { + DatabaseDescriptor.daemonInitialization(); + } + + @Before + public void setUp() + { + TestAuthenticator authenticator = new TestAuthenticator(false); + handler = new InboundHandshakeHandler(authenticator); + channel = new EmbeddedChannel(handler); + } + + @After + public void tearDown() + { + if (buf != null) + buf.release(); + 
channel.finishAndReleaseAll(); + } + + @Test + public void handleAuthenticate_Good() + { + handler = new InboundHandshakeHandler(new TestAuthenticator(true)); + channel = new EmbeddedChannel(handler); + boolean result = handler.handleAuthenticate(addr, channel.pipeline().firstContext()); + Assert.assertTrue(result); + Assert.assertTrue(channel.isOpen()); + } + + @Test + public void handleAuthenticate_Bad() + { + boolean result = handler.handleAuthenticate(addr, channel.pipeline().firstContext()); + Assert.assertFalse(result); + Assert.assertFalse(channel.isOpen()); + Assert.assertFalse(channel.isActive()); + } + + @Test + public void handleAuthenticate_BadSocketAddr() + { + boolean result = handler.handleAuthenticate(new FakeSocketAddress(), channel.pipeline().firstContext()); + Assert.assertFalse(result); + Assert.assertFalse(channel.isOpen()); + Assert.assertFalse(channel.isActive()); + } + + private static class FakeSocketAddress extends SocketAddress + { } + + @Test + public void decode_AlreadyFailed() + { + handler.setState(State.HANDSHAKE_FAIL); + buf = new FirstHandshakeMessage(MESSAGING_VERSION, MESSAGING, true).encode(PooledByteBufAllocator.DEFAULT); + handler.decode(channel.pipeline().firstContext(), buf, new ArrayList<>()); + Assert.assertFalse(channel.isOpen()); + Assert.assertFalse(channel.isActive()); + Assert.assertSame(State.HANDSHAKE_FAIL, handler.getState()); + } + + @Test + public void handleStart_NotEnoughInputBytes() throws IOException + { + ByteBuf buf = Unpooled.EMPTY_BUFFER; + State state = handler.handleStart(channel.pipeline().firstContext(), buf); + Assert.assertEquals(State.START, state); + Assert.assertTrue(channel.isOpen()); + Assert.assertTrue(channel.isActive()); + } + + @Test (expected = IOException.class) + public void handleStart_BadMagic() throws IOException + { + InboundHandshakeHandler handler = new InboundHandshakeHandler(new TestAuthenticator(false)); + EmbeddedChannel channel = new EmbeddedChannel(handler); + buf = 
Unpooled.buffer(32, 32); + + FirstHandshakeMessage first = new FirstHandshakeMessage(MESSAGING_VERSION, + MESSAGING, + true); + + buf.writeInt(MessagingService.PROTOCOL_MAGIC << 2); + buf.writeInt(first.encodeFlags()); + handler.handleStart(channel.pipeline().firstContext(), buf); + } + + @Test + public void handleStart_VersionTooHigh() throws IOException + { + channel.eventLoop(); + buf = new FirstHandshakeMessage(MESSAGING_VERSION + 1, MESSAGING, true).encode(PooledByteBufAllocator.DEFAULT); + State state = handler.handleStart(channel.pipeline().firstContext(), buf); + Assert.assertEquals(State.HANDSHAKE_FAIL, state); + Assert.assertFalse(channel.isOpen()); + Assert.assertFalse(channel.isActive()); + } + + @Test + public void handleStart_VersionLessThan3_0() throws IOException + { + buf = new FirstHandshakeMessage(VERSION_30 - 1, MESSAGING, true).encode(PooledByteBufAllocator.DEFAULT); + State state = handler.handleStart(channel.pipeline().firstContext(), buf); + Assert.assertEquals(State.HANDSHAKE_FAIL, state); + + Assert.assertFalse(channel.isOpen()); + Assert.assertFalse(channel.isActive()); + } + + @Test + public void handleStart_HappyPath_Messaging() throws IOException + { + buf = new FirstHandshakeMessage(MESSAGING_VERSION, MESSAGING, true).encode(PooledByteBufAllocator.DEFAULT); + State state = handler.handleStart(channel.pipeline().firstContext(), buf); + Assert.assertEquals(State.AWAIT_MESSAGING_START_RESPONSE, state); + if (buf.refCnt() > 0) + buf.release(); + + buf = new ThirdHandshakeMessage(MESSAGING_VERSION, addr.getAddress()).encode(PooledByteBufAllocator.DEFAULT); + state = handler.handleMessagingStartResponse(channel.pipeline().firstContext(), buf); + + Assert.assertEquals(State.MESSAGING_HANDSHAKE_COMPLETE, state); + Assert.assertTrue(channel.isOpen()); + Assert.assertTrue(channel.isActive()); + Assert.assertFalse(channel.outboundMessages().isEmpty()); + channel.releaseOutbound(); + } + + @Test + public void 
handleMessagingStartResponse_NotEnoughInputBytes() throws IOException + { + ByteBuf buf = Unpooled.EMPTY_BUFFER; + State state = handler.handleMessagingStartResponse(channel.pipeline().firstContext(), buf); + Assert.assertEquals(State.AWAIT_MESSAGING_START_RESPONSE, state); + Assert.assertTrue(channel.isOpen()); + Assert.assertTrue(channel.isActive()); + } + + @Test + public void handleMessagingStartResponse_BadMaxVersion() throws IOException + { + buf = Unpooled.buffer(32, 32); + buf.writeInt(MESSAGING_VERSION + 1); + CompactEndpointSerializationHelper.serialize(addr.getAddress(), new ByteBufOutputStream(buf)); + State state = handler.handleMessagingStartResponse(channel.pipeline().firstContext(), buf); + Assert.assertEquals(State.HANDSHAKE_FAIL, state); + Assert.assertFalse(channel.isOpen()); + Assert.assertFalse(channel.isActive()); + } + + @Test + public void handleMessagingStartResponse_HappyPath() throws IOException + { + buf = Unpooled.buffer(32, 32); + buf.writeInt(MESSAGING_VERSION); + CompactEndpointSerializationHelper.serialize(addr.getAddress(), new ByteBufOutputStream(buf)); + State state = handler.handleMessagingStartResponse(channel.pipeline().firstContext(), buf); + Assert.assertEquals(State.MESSAGING_HANDSHAKE_COMPLETE, state); + Assert.assertTrue(channel.isOpen()); + Assert.assertTrue(channel.isActive()); + } + + @Test + public void setupPipeline_NoCompression() + { + ChannelPipeline pipeline = channel.pipeline(); + Assert.assertNotNull(pipeline.get(InboundHandshakeHandler.class)); + + handler.setupMessagingPipeline(pipeline, addr.getAddress(), false, MESSAGING_VERSION); + Assert.assertNotNull(pipeline.get(MessageInHandler.class)); + Assert.assertNull(pipeline.get(Lz4FrameDecoder.class)); + Assert.assertNull(pipeline.get(Lz4FrameEncoder.class)); + Assert.assertNull(pipeline.get(InboundHandshakeHandler.class)); + } + + @Test + public void setupPipeline_WithCompression() + { + ChannelPipeline pipeline = channel.pipeline(); + 
Assert.assertNotNull(pipeline.get(InboundHandshakeHandler.class)); + + handler.setupMessagingPipeline(pipeline, addr.getAddress(), true, MESSAGING_VERSION); + Assert.assertNotNull(pipeline.get(MessageInHandler.class)); + Assert.assertNotNull(pipeline.get(Lz4FrameDecoder.class)); + Assert.assertNull(pipeline.get(Lz4FrameEncoder.class)); + Assert.assertNull(pipeline.get(InboundHandshakeHandler.class)); + } + + @Test + public void failHandshake() + { + ChannelPromise future = channel.newPromise(); + handler.setHandshakeTimeout(future); + Assert.assertFalse(future.isCancelled()); + Assert.assertTrue(channel.isOpen()); + handler.failHandshake(channel.pipeline().firstContext()); + Assert.assertSame(State.HANDSHAKE_FAIL, handler.getState()); + Assert.assertTrue(future.isCancelled()); + Assert.assertFalse(channel.isOpen()); + } + + @Test + public void failHandshake_AlreadyConnected() + { + ChannelPromise future = channel.newPromise(); + handler.setHandshakeTimeout(future); + Assert.assertFalse(future.isCancelled()); + Assert.assertTrue(channel.isOpen()); + handler.setState(State.MESSAGING_HANDSHAKE_COMPLETE); + handler.failHandshake(channel.pipeline().firstContext()); + Assert.assertSame(State.MESSAGING_HANDSHAKE_COMPLETE, handler.getState()); + Assert.assertTrue(channel.isOpen()); + } + + @Test + public void failHandshake_TaskIsCancelled() + { + ChannelPromise future = channel.newPromise(); + future.cancel(false); + handler.setHandshakeTimeout(future); + handler.setState(State.AWAIT_MESSAGING_START_RESPONSE); + Assert.assertTrue(channel.isOpen()); + handler.failHandshake(channel.pipeline().firstContext()); + Assert.assertSame(State.AWAIT_MESSAGING_START_RESPONSE, handler.getState()); + Assert.assertTrue(channel.isOpen()); + } +} http://git-wip-us.apache.org/repos/asf/cassandra/blob/356dc3c2/test/unit/org/apache/cassandra/net/async/MessageInHandlerTest.java ---------------------------------------------------------------------- diff --git 
a/test/unit/org/apache/cassandra/net/async/MessageInHandlerTest.java b/test/unit/org/apache/cassandra/net/async/MessageInHandlerTest.java new file mode 100644 index 0000000..bb82d2c --- /dev/null +++ b/test/unit/org/apache/cassandra/net/async/MessageInHandlerTest.java @@ -0,0 +1,242 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.cassandra.net.async; + +import java.io.EOFException; +import java.io.IOException; +import java.net.InetSocketAddress; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.function.BiConsumer; + +import com.google.common.base.Charsets; +import org.junit.After; +import org.junit.Assert; +import org.junit.BeforeClass; +import org.junit.Test; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufOutputStream; +import io.netty.buffer.PooledByteBufAllocator; +import io.netty.buffer.Unpooled; +import io.netty.channel.Channel; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.channel.epoll.EpollEventLoopGroup; +import io.netty.channel.nio.NioEventLoopGroup; +import io.netty.util.concurrent.Future; +import org.apache.cassandra.config.DatabaseDescriptor; +import org.apache.cassandra.net.MessageIn; +import org.apache.cassandra.net.MessageOut; +import org.apache.cassandra.net.MessagingService; +import org.apache.cassandra.net.async.MessageInHandler.MessageHeader; +import org.apache.cassandra.utils.FBUtilities; +import org.apache.cassandra.utils.NanoTimeToCurrentTimeMillis; + +public class MessageInHandlerTest +{ + private static final InetSocketAddress addr = new InetSocketAddress("127.0.0.1", 0); + private static final int MSG_VERSION = MessagingService.current_version; + + private static final int MSG_ID = 42; + + private ByteBuf buf; + + @BeforeClass + public static void before() + { + DatabaseDescriptor.daemonInitialization(); + } + + @After + public void tearDown() + { + if (buf != null && buf.refCnt() > 0) + buf.release(); + } + + @Test + public void decode_BadMagic() throws Exception + { + int len = MessageInHandler.FIRST_SECTION_BYTE_COUNT; + buf = Unpooled.buffer(len, len); + buf.writeInt(-1); + buf.writerIndex(len); + + MessageInHandler handler = new MessageInHandler(addr.getAddress(), MSG_VERSION, null); + 
EmbeddedChannel channel = new EmbeddedChannel(handler); + Assert.assertTrue(channel.isOpen()); + channel.writeInbound(buf); + Assert.assertFalse(channel.isOpen()); + } + + @Test + public void decode_HappyPath_NoParameters() throws Exception + { + MessageInWrapper result = decode_HappyPath(Collections.emptyMap()); + Assert.assertTrue(result.messageIn.parameters.isEmpty()); + } + + @Test + public void decode_HappyPath_WithParameters() throws Exception + { + Map<String, byte[]> parameters = new HashMap<>(); + parameters.put("p1", "val1".getBytes(Charsets.UTF_8)); + parameters.put("p2", "val2".getBytes(Charsets.UTF_8)); + MessageInWrapper result = decode_HappyPath(parameters); + Assert.assertEquals(2, result.messageIn.parameters.size()); + } + + private MessageInWrapper decode_HappyPath(Map<String, byte[]> parameters) throws Exception + { + MessageOut msgOut = new MessageOut(MessagingService.Verb.ECHO); + for (Map.Entry<String, byte[]> param : parameters.entrySet()) + msgOut = msgOut.withParameter(param.getKey(), param.getValue()); + serialize(msgOut); + + MessageInWrapper wrapper = new MessageInWrapper(); + MessageInHandler handler = new MessageInHandler(addr.getAddress(), MSG_VERSION, wrapper.messageConsumer); + List<Object> out = new ArrayList<>(); + handler.decode(null, buf, out); + + Assert.assertNotNull(wrapper.messageIn); + Assert.assertEquals(MSG_ID, wrapper.id); + Assert.assertEquals(msgOut.from, wrapper.messageIn.from); + Assert.assertEquals(msgOut.verb, wrapper.messageIn.verb); + Assert.assertTrue(out.isEmpty()); + + return wrapper; + } + + private void serialize(MessageOut msgOut) throws IOException + { + buf = Unpooled.buffer(1024, 1024); // 1k should be enough for everybody! 
+ buf.writeInt(MessagingService.PROTOCOL_MAGIC); + buf.writeInt(MSG_ID); // this is the id + buf.writeInt((int) NanoTimeToCurrentTimeMillis.convert(System.nanoTime())); + + msgOut.serialize(new ByteBufDataOutputPlus(buf), MSG_VERSION); + } + + @Test + public void decode_WithHalfReceivedParameters() throws Exception + { + MessageOut msgOut = new MessageOut(MessagingService.Verb.ECHO); + msgOut = msgOut.withParameter("p3", "val1".getBytes(Charsets.UTF_8)); + + serialize(msgOut); + + // move the write index pointer back a few bytes to simulate like the full bytes are not present. + // yeah, it's lame, but it tests the basics of what is happening during the deserialiization + int originalWriterIndex = buf.writerIndex(); + buf.writerIndex(originalWriterIndex - 6); + + MessageInWrapper wrapper = new MessageInWrapper(); + MessageInHandler handler = new MessageInHandler(addr.getAddress(), MSG_VERSION, wrapper.messageConsumer); + List<Object> out = new ArrayList<>(); + handler.decode(null, buf, out); + + Assert.assertNull(wrapper.messageIn); + + MessageHeader header = handler.getMessageHeader(); + Assert.assertEquals(MSG_ID, header.messageId); + Assert.assertEquals(msgOut.verb, header.verb); + Assert.assertEquals(msgOut.from, header.from); + Assert.assertTrue(out.isEmpty()); + + // now, set the writer index back to the original value to pretend that we actually got more bytes in + buf.writerIndex(originalWriterIndex); + handler.decode(null, buf, out); + Assert.assertNotNull(wrapper.messageIn); + Assert.assertTrue(out.isEmpty()); + } + + @Test + public void canReadNextParam_HappyPath() throws IOException + { + buildParamBuf(13); + Assert.assertTrue(MessageInHandler.canReadNextParam(buf)); + } + + @Test + public void canReadNextParam_OnlyFirstByte() throws IOException + { + buildParamBuf(13); + buf.writerIndex(1); + Assert.assertFalse(MessageInHandler.canReadNextParam(buf)); + } + + @Test + public void canReadNextParam_PartialUTF() throws IOException + { + buildParamBuf(13); 
+ buf.writerIndex(5); + Assert.assertFalse(MessageInHandler.canReadNextParam(buf)); + } + + @Test + public void canReadNextParam_TruncatedValueLength() throws IOException + { + buildParamBuf(13); + buf.writerIndex(buf.writerIndex() - 13 - 2); + Assert.assertFalse(MessageInHandler.canReadNextParam(buf)); + } + + @Test + public void canReadNextParam_MissingLastBytes() throws IOException + { + buildParamBuf(13); + buf.writerIndex(buf.writerIndex() - 2); + Assert.assertFalse(MessageInHandler.canReadNextParam(buf)); + } + + private void buildParamBuf(int valueLength) throws IOException + { + buf = Unpooled.buffer(1024, 1024); // 1k should be enough for everybody! + ByteBufDataOutputPlus output = new ByteBufDataOutputPlus(buf); + output.writeUTF("name"); + byte[] array = new byte[valueLength]; + output.writeInt(array.length); + output.write(array); + } + + @Test + public void exceptionHandled() + { + MessageInHandler handler = new MessageInHandler(addr.getAddress(), MSG_VERSION, null); + EmbeddedChannel channel = new EmbeddedChannel(handler); + Assert.assertTrue(channel.isOpen()); + handler.exceptionCaught(channel.pipeline().firstContext(), new EOFException()); + Assert.assertFalse(channel.isOpen()); + } + + private static class MessageInWrapper + { + MessageIn messageIn; + int id; + + final BiConsumer<MessageIn, Integer> messageConsumer = (messageIn, integer) -> + { + this.messageIn = messageIn; + this.id = integer; + }; + } +} http://git-wip-us.apache.org/repos/asf/cassandra/blob/356dc3c2/test/unit/org/apache/cassandra/net/async/MessageOutHandlerTest.java ---------------------------------------------------------------------- diff --git a/test/unit/org/apache/cassandra/net/async/MessageOutHandlerTest.java b/test/unit/org/apache/cassandra/net/async/MessageOutHandlerTest.java new file mode 100644 index 0000000..566dfdb --- /dev/null +++ b/test/unit/org/apache/cassandra/net/async/MessageOutHandlerTest.java @@ -0,0 +1,289 @@ +/* + * Licensed to the Apache Software 
Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.net.async; + +import java.io.IOException; +import java.net.InetSocketAddress; +import java.util.HashMap; +import java.util.Optional; +import java.util.UUID; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeoutException; + +import org.junit.Assert; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; + +import com.sun.org.apache.bcel.internal.generic.DDIV; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelOutboundHandlerAdapter; +import io.netty.channel.ChannelPromise; +import io.netty.channel.DefaultChannelPromise; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.handler.codec.UnsupportedMessageTypeException; +import io.netty.handler.timeout.IdleStateEvent; +import org.apache.cassandra.config.DatabaseDescriptor; +import org.apache.cassandra.io.IVersionedSerializer; +import org.apache.cassandra.io.util.DataInputPlus; +import org.apache.cassandra.io.util.DataOutputPlus; +import org.apache.cassandra.net.MessageOut; +import org.apache.cassandra.net.MessagingService; +import org.apache.cassandra.tracing.Tracing; +import 
org.apache.cassandra.utils.UUIDGen; + +public class MessageOutHandlerTest +{ + private static final int MESSAGING_VERSION = MessagingService.current_version; + + private ChannelWriter channelWriter; + private EmbeddedChannel channel; + private MessageOutHandler handler; + + @BeforeClass + public static void before() + { + DatabaseDescriptor.daemonInitialization(); + DatabaseDescriptor.createAllDirectories(); + } + + @Before + public void setup() + { + setup(MessageOutHandler.AUTO_FLUSH_THRESHOLD); + } + + private void setup(int flushThreshold) + { + OutboundConnectionIdentifier connectionId = OutboundConnectionIdentifier.small(new InetSocketAddress("127.0.0.1", 0), + new InetSocketAddress("127.0.0.2", 0)); + OutboundMessagingConnection omc = new NonSendingOutboundMessagingConnection(connectionId, null, Optional.empty()); + channel = new EmbeddedChannel(); + channelWriter = ChannelWriter.create(channel, omc::handleMessageResult, Optional.empty()); + handler = new MessageOutHandler(connectionId, MESSAGING_VERSION, channelWriter, () -> null, flushThreshold); + channel.pipeline().addLast(handler); + } + + @Test + public void write_NoFlush() throws ExecutionException, InterruptedException, TimeoutException + { + MessageOut message = new MessageOut(MessagingService.Verb.ECHO); + ChannelFuture future = channel.write(new QueuedMessage(message, 42)); + Assert.assertTrue(!future.isDone()); + Assert.assertFalse(channel.releaseOutbound()); + } + + @Test + public void write_WithFlush() throws ExecutionException, InterruptedException, TimeoutException + { + setup(1); + MessageOut message = new MessageOut(MessagingService.Verb.ECHO); + ChannelFuture future = channel.write(new QueuedMessage(message, 42)); + Assert.assertTrue(future.isSuccess()); + Assert.assertTrue(channel.releaseOutbound()); + } + + @Test + public void serializeMessage() throws IOException + { + channelWriter.pendingMessageCount.set(1); + QueuedMessage msg = new QueuedMessage(new 
MessageOut(MessagingService.Verb.INTERNAL_RESPONSE), 1); + ChannelFuture future = channel.writeAndFlush(msg); + + Assert.assertTrue(future.isSuccess()); + Assert.assertTrue(1 <= channel.outboundMessages().size()); + Assert.assertTrue(channel.releaseOutbound()); + } + + @Test + public void wrongMessageType() + { + ChannelPromise promise = new DefaultChannelPromise(channel); + Assert.assertFalse(handler.isMessageValid("this is the wrong message type", promise)); + + Assert.assertFalse(promise.isSuccess()); + Assert.assertNotNull(promise.cause()); + Assert.assertSame(UnsupportedMessageTypeException.class, promise.cause().getClass()); + } + + @Test + public void unexpiredMessage() + { + QueuedMessage msg = new QueuedMessage(new MessageOut(MessagingService.Verb.INTERNAL_RESPONSE), 1); + ChannelPromise promise = new DefaultChannelPromise(channel); + Assert.assertTrue(handler.isMessageValid(msg, promise)); + + // we won't know if it was successful yet, but we'll know if it's a failure because cause will be set + Assert.assertNull(promise.cause()); + } + + @Test + public void expiredMessage() + { + QueuedMessage msg = new QueuedMessage(new MessageOut(MessagingService.Verb.INTERNAL_RESPONSE), 1, 0, true, true); + ChannelPromise promise = new DefaultChannelPromise(channel); + Assert.assertFalse(handler.isMessageValid(msg, promise)); + + Assert.assertFalse(promise.isSuccess()); + Assert.assertNotNull(promise.cause()); + Assert.assertSame(ExpiredException.class, promise.cause().getClass()); + Assert.assertTrue(channel.outboundMessages().isEmpty()); + } + + @Test + public void write_MessageTooLarge() + { + write_BadMessageSize(Integer.MAX_VALUE + 1L); // long arithmetic: the int sum would wrap to Integer.MIN_VALUE and only re-test the negative-size case + } + + @Test + public void write_MessageSizeIsBananas() + { + write_BadMessageSize(Integer.MIN_VALUE + 10000); + } + + private void write_BadMessageSize(long size) + { + IVersionedSerializer<Object> serializer = new IVersionedSerializer<Object>() + { + public void serialize(Object o, DataOutputPlus out, int version) + { } +
public Object deserialize(DataInputPlus in, int version) + { + return null; + } + + public long serializedSize(Object o, int version) + { + return size; + } + }; + MessageOut message = new MessageOut(MessagingService.Verb.UNUSED_5, "payload", serializer); + ChannelFuture future = channel.write(new QueuedMessage(message, 42)); + Throwable t = future.cause(); + Assert.assertNotNull(t); + Assert.assertSame(IllegalStateException.class, t.getClass()); + Assert.assertTrue(channel.isOpen()); + Assert.assertFalse(channel.releaseOutbound()); + } + + @Test + public void writeForceExceptionPath() + { + IVersionedSerializer<Object> serializer = new IVersionedSerializer<Object>() + { + public void serialize(Object o, DataOutputPlus out, int version) + { + throw new RuntimeException("this exception is part of the test - DON'T PANIC"); + } + + public Object deserialize(DataInputPlus in, int version) + { + return null; + } + + public long serializedSize(Object o, int version) + { + return 42; + } + }; + MessageOut message = new MessageOut(MessagingService.Verb.UNUSED_5, "payload", serializer); + ChannelFuture future = channel.write(new QueuedMessage(message, 42)); + Throwable t = future.cause(); + Assert.assertNotNull(t); + Assert.assertFalse(channel.isOpen()); + Assert.assertFalse(channel.releaseOutbound()); + } + + @Test + public void captureTracingInfo_ForceException() + { + MessageOut message = new MessageOut(MessagingService.Verb.INTERNAL_RESPONSE) + .withParameter(Tracing.TRACE_HEADER, new byte[9]); + handler.captureTracingInfo(new QueuedMessage(message, 42)); + } + + @Test + public void captureTracingInfo_UnknownSession() + { + UUID uuid = UUID.randomUUID(); + MessageOut message = new MessageOut(MessagingService.Verb.INTERNAL_RESPONSE) + .withParameter(Tracing.TRACE_HEADER, UUIDGen.decompose(uuid)); + handler.captureTracingInfo(new QueuedMessage(message, 42)); + } + + @Test + public void captureTracingInfo_KnownSession() + { + Tracing.instance.newSession(new HashMap<>()); + 
MessageOut message = new MessageOut(MessagingService.Verb.REQUEST_RESPONSE); + handler.captureTracingInfo(new QueuedMessage(message, 42)); + } + + @Test + public void userEventTriggered_RandomObject() + { + Assert.assertTrue(channel.isOpen()); + ChannelUserEventSender sender = new ChannelUserEventSender(); + channel.pipeline().addFirst(sender); + sender.sendEvent("ThisIsAFakeEvent"); + Assert.assertTrue(channel.isOpen()); + } + + @Test + public void userEventTriggered_Idle_NoPendingBytes() + { + Assert.assertTrue(channel.isOpen()); + ChannelUserEventSender sender = new ChannelUserEventSender(); + channel.pipeline().addFirst(sender); + sender.sendEvent(IdleStateEvent.WRITER_IDLE_STATE_EVENT); + Assert.assertTrue(channel.isOpen()); + } + + @Test + public void userEventTriggered_Idle_WithPendingBytes() + { + Assert.assertTrue(channel.isOpen()); + ChannelUserEventSender sender = new ChannelUserEventSender(); + channel.pipeline().addFirst(sender); + + MessageOut message = new MessageOut(MessagingService.Verb.INTERNAL_RESPONSE); + channel.writeOutbound(new QueuedMessage(message, 42)); + sender.sendEvent(IdleStateEvent.WRITER_IDLE_STATE_EVENT); + Assert.assertFalse(channel.isOpen()); + } + + private static class ChannelUserEventSender extends ChannelOutboundHandlerAdapter + { + private ChannelHandlerContext ctx; + + @Override + public void handlerAdded(final ChannelHandlerContext ctx) throws Exception + { + this.ctx = ctx; + } + + private void sendEvent(Object event) + { + ctx.fireUserEventTriggered(event); + } + } +} http://git-wip-us.apache.org/repos/asf/cassandra/blob/356dc3c2/test/unit/org/apache/cassandra/net/async/NettyFactoryTest.java ---------------------------------------------------------------------- diff --git a/test/unit/org/apache/cassandra/net/async/NettyFactoryTest.java b/test/unit/org/apache/cassandra/net/async/NettyFactoryTest.java new file mode 100644 index 0000000..c4cc7e6 --- /dev/null +++ 
b/test/unit/org/apache/cassandra/net/async/NettyFactoryTest.java @@ -0,0 +1,300 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.net.async; + +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.util.Optional; + +import com.google.common.net.InetAddresses; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; + +import io.netty.bootstrap.Bootstrap; +import io.netty.channel.Channel; +import io.netty.channel.DefaultEventLoop; +import io.netty.channel.EventLoopGroup; +import io.netty.channel.epoll.EpollEventLoopGroup; +import io.netty.channel.epoll.EpollServerSocketChannel; +import io.netty.channel.group.ChannelGroup; +import io.netty.channel.group.DefaultChannelGroup; +import io.netty.channel.nio.NioEventLoopGroup; +import io.netty.channel.socket.nio.NioServerSocketChannel; +import io.netty.channel.socket.nio.NioSocketChannel; +import io.netty.handler.ssl.SslHandler; +import io.netty.util.concurrent.GlobalEventExecutor; +import org.apache.cassandra.auth.AllowAllInternodeAuthenticator; +import org.apache.cassandra.auth.IInternodeAuthenticator; +import 
org.apache.cassandra.config.Config; +import org.apache.cassandra.config.DatabaseDescriptor; +import org.apache.cassandra.config.EncryptionOptions.ServerEncryptionOptions; +import org.apache.cassandra.config.EncryptionOptions.ServerEncryptionOptions.InternodeEncryption; +import org.apache.cassandra.exceptions.ConfigurationException; +import org.apache.cassandra.net.MessagingService; +import org.apache.cassandra.net.async.NettyFactory.InboundInitializer; +import org.apache.cassandra.net.async.NettyFactory.OutboundInitializer; +import org.apache.cassandra.utils.FBUtilities; +import org.apache.cassandra.utils.NativeLibrary; + +public class NettyFactoryTest +{ + private static final InetSocketAddress LOCAL_ADDR = new InetSocketAddress("127.0.0.1", 9876); + private static final InetSocketAddress REMOTE_ADDR = new InetSocketAddress("127.0.0.2", 9876); + private static final int receiveBufferSize = 1 << 16; + private static final IInternodeAuthenticator AUTHENTICATOR = new AllowAllInternodeAuthenticator(); + + private ChannelGroup channelGroup; + private NettyFactory factory; + + @BeforeClass + public static void before() + { + DatabaseDescriptor.daemonInitialization(); + } + + @Before + public void setUp() + { + channelGroup = new DefaultChannelGroup(GlobalEventExecutor.INSTANCE); + } + + @After + public void tearDown() + { + if (factory != null) + factory.close(); + } + + @Test + public void createServerChannel_Epoll() + { + Channel inboundChannel = createServerChannel(true); + if (inboundChannel == null) + return; + Assert.assertEquals(EpollServerSocketChannel.class, inboundChannel.getClass()); + inboundChannel.close(); + } + + private Channel createServerChannel(boolean useEpoll) + { + InboundInitializer inboundInitializer = new InboundInitializer(AUTHENTICATOR, null, channelGroup); + factory = new NettyFactory(useEpoll); + + try + { + return factory.createInboundChannel(LOCAL_ADDR, inboundInitializer, receiveBufferSize); + } + catch (Exception e) + { + if 
(NativeLibrary.osType == NativeLibrary.OSType.LINUX) + throw e; + + return null; + } + } + + @Test + public void createServerChannel_Nio() + { + Channel inboundChannel = createServerChannel(false); + Assert.assertNotNull("we should always be able to get a NIO channel", inboundChannel); + Assert.assertEquals(NioServerSocketChannel.class, inboundChannel.getClass()); + inboundChannel.close(); + } + + @Test(expected = ConfigurationException.class) + public void createServerChannel_SecondAttemptToBind() + { + Channel inboundChannel = null; + try + { + // bind LOCAL_ADDR twice: the first bind must succeed, the second must fail with ConfigurationException + InboundInitializer inboundInitializer = new InboundInitializer(AUTHENTICATOR, null, channelGroup); + inboundChannel = NettyFactory.instance.createInboundChannel(LOCAL_ADDR, inboundInitializer, receiveBufferSize); + NettyFactory.instance.createInboundChannel(LOCAL_ADDR, inboundInitializer, receiveBufferSize); + } + finally + { + if (inboundChannel != null) + inboundChannel.close(); + } + } + + @Test(expected = ConfigurationException.class) + public void createServerChannel_UnbindableAddress() + { + InetSocketAddress addr = new InetSocketAddress("1.1.1.1", 9876); + InboundInitializer inboundInitializer = new InboundInitializer(AUTHENTICATOR, null, channelGroup); + NettyFactory.instance.createInboundChannel(addr, inboundInitializer, receiveBufferSize); + } + + @Test + public void determineAcceptGroupSize() + { + Assert.assertEquals(1, NettyFactory.determineAcceptGroupSize(InternodeEncryption.none)); + Assert.assertEquals(1, NettyFactory.determineAcceptGroupSize(InternodeEncryption.all)); + Assert.assertEquals(2, NettyFactory.determineAcceptGroupSize(InternodeEncryption.rack)); + Assert.assertEquals(2, NettyFactory.determineAcceptGroupSize(InternodeEncryption.dc)); + + InetAddress originalBroadcastAddr = FBUtilities.getBroadcastAddress(); + try + { + FBUtilities.setBroadcastInetAddress(InetAddresses.increment(FBUtilities.getLocalAddress())); +
DatabaseDescriptor.setListenOnBroadcastAddress(true); + + Assert.assertEquals(2, NettyFactory.determineAcceptGroupSize(InternodeEncryption.none)); + Assert.assertEquals(2, NettyFactory.determineAcceptGroupSize(InternodeEncryption.all)); + Assert.assertEquals(4, NettyFactory.determineAcceptGroupSize(InternodeEncryption.rack)); + Assert.assertEquals(4, NettyFactory.determineAcceptGroupSize(InternodeEncryption.dc)); + } + finally + { + FBUtilities.setBroadcastInetAddress(originalBroadcastAddr); + DatabaseDescriptor.setListenOnBroadcastAddress(false); + } + } + + @Test + public void getEventLoopGroup_EpollWithIoRatioBoost() + { + getEventLoopGroup_Epoll(true); + } + + private EpollEventLoopGroup getEventLoopGroup_Epoll(boolean ioBoost) + { + EventLoopGroup eventLoopGroup; + try + { + eventLoopGroup = NettyFactory.getEventLoopGroup(true, 1, "testEventLoopGroup", ioBoost); + } + catch (Exception e) + { + if (NativeLibrary.osType == NativeLibrary.OSType.LINUX) + throw e; + + // ignore as epoll is only available on linux platforms, so don't fail the test on other OSes + return null; + } + + Assert.assertTrue(eventLoopGroup instanceof EpollEventLoopGroup); + return (EpollEventLoopGroup) eventLoopGroup; + } + + @Test + public void getEventLoopGroup_EpollWithoutIoRatioBoost() + { + getEventLoopGroup_Epoll(false); + } + + @Test + public void getEventLoopGroup_NioWithoutIoRatioBoost() + { + getEventLoopGroup_Nio(false); // no io-ratio boost; passing true here would only duplicate getEventLoopGroup_NioWithIoRatioBoost + } + + private NioEventLoopGroup getEventLoopGroup_Nio(boolean ioBoost) + { + EventLoopGroup eventLoopGroup = NettyFactory.getEventLoopGroup(false, 1, "testEventLoopGroup", ioBoost); + Assert.assertTrue(eventLoopGroup instanceof NioEventLoopGroup); + return (NioEventLoopGroup) eventLoopGroup; + } + + @Test + public void getEventLoopGroup_NioWithIoRatioBoost() + { + getEventLoopGroup_Nio(true); + } + + @Test + public void createOutboundBootstrap_Epoll() + { + Bootstrap bootstrap = createOutboundBootstrap(true); + Assert.assertEquals(EpollEventLoopGroup.class,
bootstrap.config().group().getClass()); + } + + private Bootstrap createOutboundBootstrap(boolean useEpoll) + { + factory = new NettyFactory(useEpoll); + OutboundConnectionIdentifier id = OutboundConnectionIdentifier.gossip(LOCAL_ADDR, REMOTE_ADDR); + OutboundConnectionParams params = OutboundConnectionParams.builder() + .connectionId(id) + .coalescingStrategy(Optional.empty()) + .protocolVersion(MessagingService.current_version) + .build(); + return factory.createOutboundBootstrap(params); + } + + @Test + public void createOutboundBootstrap_Nio() + { + Bootstrap bootstrap = createOutboundBootstrap(false); + Assert.assertEquals(NioEventLoopGroup.class, bootstrap.config().group().getClass()); + } + + @Test + public void createInboundInitializer_WithoutSsl() throws Exception + { + InboundInitializer initializer = new InboundInitializer(AUTHENTICATOR, null, channelGroup); + NioSocketChannel channel = new NioSocketChannel(); + initializer.initChannel(channel); + Assert.assertNull(channel.pipeline().get(SslHandler.class)); + } + + private ServerEncryptionOptions encOptions() + { + ServerEncryptionOptions encryptionOptions; + encryptionOptions = new ServerEncryptionOptions(); + encryptionOptions.keystore = "test/conf/cassandra_ssl_test.keystore"; + encryptionOptions.keystore_password = "cassandra"; + encryptionOptions.truststore = "test/conf/cassandra_ssl_test.truststore"; + encryptionOptions.truststore_password = "cassandra"; + encryptionOptions.require_client_auth = false; + encryptionOptions.cipher_suites = new String[] {"TLS_RSA_WITH_AES_128_CBC_SHA"}; + return encryptionOptions; + } + @Test + public void createInboundInitializer_WithSsl() throws Exception + { + ServerEncryptionOptions encryptionOptions = encOptions(); + InboundInitializer initializer = new InboundInitializer(AUTHENTICATOR, encryptionOptions, channelGroup); + NioSocketChannel channel = new NioSocketChannel(); + Assert.assertNull(channel.pipeline().get(SslHandler.class)); + 
initializer.initChannel(channel); + Assert.assertNotNull(channel.pipeline().get(SslHandler.class)); + } + + @Test + public void createOutboundInitializer_WithSsl() throws Exception + { + OutboundConnectionIdentifier id = OutboundConnectionIdentifier.gossip(LOCAL_ADDR, REMOTE_ADDR); + OutboundConnectionParams params = OutboundConnectionParams.builder() + .connectionId(id) + .encryptionOptions(encOptions()) + .protocolVersion(MessagingService.current_version) + .build(); + OutboundInitializer outboundInitializer = new OutboundInitializer(params); + NioSocketChannel channel = new NioSocketChannel(); + Assert.assertNull(channel.pipeline().get(SslHandler.class)); + outboundInitializer.initChannel(channel); + Assert.assertNotNull(channel.pipeline().get(SslHandler.class)); + } +} http://git-wip-us.apache.org/repos/asf/cassandra/blob/356dc3c2/test/unit/org/apache/cassandra/net/async/NonSendingOutboundMessagingConnection.java ---------------------------------------------------------------------- diff --git a/test/unit/org/apache/cassandra/net/async/NonSendingOutboundMessagingConnection.java b/test/unit/org/apache/cassandra/net/async/NonSendingOutboundMessagingConnection.java new file mode 100644 index 0000000..b0b15b8 --- /dev/null +++ b/test/unit/org/apache/cassandra/net/async/NonSendingOutboundMessagingConnection.java @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.net.async; + +import java.util.Optional; + +import org.apache.cassandra.auth.AllowAllInternodeAuthenticator; +import org.apache.cassandra.config.EncryptionOptions; +import org.apache.cassandra.utils.CoalescingStrategies; + +class NonSendingOutboundMessagingConnection extends OutboundMessagingConnection +{ + boolean sendMessageInvoked; + + NonSendingOutboundMessagingConnection(OutboundConnectionIdentifier connectionId, EncryptionOptions.ServerEncryptionOptions encryptionOptions, Optional<CoalescingStrategies.CoalescingStrategy> coalescingStrategy) + { + super(connectionId, encryptionOptions, coalescingStrategy, new AllowAllInternodeAuthenticator()); + } + + @Override + boolean sendMessage(QueuedMessage queuedMessage) + { + sendMessageInvoked = true; + return true; + } +} http://git-wip-us.apache.org/repos/asf/cassandra/blob/356dc3c2/test/unit/org/apache/cassandra/net/async/OutboundConnectionParamsTest.java ---------------------------------------------------------------------- diff --git a/test/unit/org/apache/cassandra/net/async/OutboundConnectionParamsTest.java b/test/unit/org/apache/cassandra/net/async/OutboundConnectionParamsTest.java new file mode 100644 index 0000000..0ce4968 --- /dev/null +++ b/test/unit/org/apache/cassandra/net/async/OutboundConnectionParamsTest.java @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.net.async; + +import org.junit.Test; + +public class OutboundConnectionParamsTest +{ + @Test (expected = IllegalArgumentException.class) + public void build_SendSizeLessThanZero() + { + OutboundConnectionParams.builder().sendBufferSize(-1).build(); + } + + @Test (expected = IllegalArgumentException.class) + public void build_SendSizeHuge() + { + OutboundConnectionParams.builder().sendBufferSize(1 << 30).build(); + } +} http://git-wip-us.apache.org/repos/asf/cassandra/blob/356dc3c2/test/unit/org/apache/cassandra/net/async/OutboundHandshakeHandlerTest.java ---------------------------------------------------------------------- diff --git a/test/unit/org/apache/cassandra/net/async/OutboundHandshakeHandlerTest.java b/test/unit/org/apache/cassandra/net/async/OutboundHandshakeHandlerTest.java new file mode 100644 index 0000000..f8bfab1 --- /dev/null +++ b/test/unit/org/apache/cassandra/net/async/OutboundHandshakeHandlerTest.java @@ -0,0 +1,209 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.cassandra.net.async;
+
+import java.net.InetSocketAddress;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Optional;
+
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.PooledByteBufAllocator;
+import io.netty.buffer.Unpooled;
+import io.netty.channel.ChannelOutboundHandlerAdapter;
+import io.netty.channel.ChannelPipeline;
+import io.netty.channel.embedded.EmbeddedChannel;
+import io.netty.handler.codec.compression.Lz4FrameDecoder;
+import io.netty.handler.codec.compression.Lz4FrameEncoder;
+import org.apache.cassandra.config.DatabaseDescriptor;
+import org.apache.cassandra.net.MessagingService;
+import org.apache.cassandra.net.async.HandshakeProtocol.SecondHandshakeMessage;
+import org.apache.cassandra.net.async.OutboundHandshakeHandler.HandshakeResult;
+
+import static org.apache.cassandra.net.async.OutboundHandshakeHandler.HandshakeResult.UNKNOWN_PROTOCOL_VERSION;
+
+/**
+ * Tests for {@link OutboundHandshakeHandler}. They cover decoding of the peer's {@link SecondHandshakeMessage} and
+ * the pipeline installed by {@code setupPipeline}. Each test runs against an {@link EmbeddedChannel} whose first
+ * handler is the handler under test.
+ */
+public class OutboundHandshakeHandlerTest
+{
+    private static final int MESSAGING_VERSION = MessagingService.current_version;
+    private static final InetSocketAddress localAddr = new InetSocketAddress("127.0.0.1", 0);
+    private static final InetSocketAddress remoteAddr = new InetSocketAddress("127.0.0.2", 0);
+    private static final String HANDLER_NAME = "clientHandshakeHandler";
+
+    private EmbeddedChannel channel;
+    private OutboundHandshakeHandler handler;
+    private OutboundConnectionIdentifier connectionId;
+    private OutboundConnectionParams params;
+    // receives every HandshakeResult passed to the handshake callback configured in setup()
+    private CallbackHandler callbackHandler;
+    // buffer handed to the handler by a test; tearDown() releases it if it is still referenced
+    private ByteBuf buf;
+
+    /** Initializes {@link DatabaseDescriptor} once for the whole class. */
+    @BeforeClass
+    public static void before()
+    {
+        DatabaseDescriptor.daemonInitialization();
+    }
+
+    /**
+     * Builds a fresh channel and installs a handler that requests {@code MessagingService.current_version}.
+     * The handler is installed first in the pipeline under {@link #HANDLER_NAME}, and its callback records
+     * results in {@link #callbackHandler}.
+     */
+    @Before
+    public void setup()
+    {
+        channel = new EmbeddedChannel(new ChannelOutboundHandlerAdapter());
+        connectionId = OutboundConnectionIdentifier.small(localAddr, remoteAddr);
+        callbackHandler = new CallbackHandler();
+        params = OutboundConnectionParams.builder()
+                                         .connectionId(connectionId)
+                                         .callback(handshakeResult -> callbackHandler.receive(handshakeResult))
+                                         .mode(NettyFactory.Mode.MESSAGING)
+                                         .protocolVersion(MessagingService.current_version)
+                                         .coalescingStrategy(Optional.empty())
+                                         .build();
+        handler = new OutboundHandshakeHandler(params);
+        channel.pipeline().addFirst(HANDLER_NAME, handler);
+    }
+
+    /**
+     * Releases {@link #buf} if a test left it referenced. It then asserts that finishing the channel releases
+     * no pending inbound or outbound messages.
+     */
+    @After
+    public void tearDown()
+    {
+        if (buf != null && buf.refCnt() > 0)
+            buf.release();
+        Assert.assertFalse(channel.finishAndReleaseAll());
+    }
+
+    /**
+     * Decoding an input too small to hold a handshake message must consume nothing and produce nothing.
+     * NOTE(review): Unpooled.buffer(2, 2) has capacity 2 but no written bytes, so the decoder actually sees 0
+     * readable bytes. Consider writing 1-2 bytes to exercise a genuinely partial message.
+     */
+    @Test
+    public void decode_SmallInput() throws Exception
+    {
+        buf = Unpooled.buffer(2, 2);
+        List<Object> out = new LinkedList<>();
+        handler.decode(channel.pipeline().firstContext(), buf, out);
+        Assert.assertEquals(0, buf.readerIndex());
+        Assert.assertTrue(out.isEmpty());
+    }
+
+    /**
+     * When the peer replies with our own messaging version, the test expects:
+     * exactly one outbound message is written, the channel stays open, and the callback reports SUCCESS with that
+     * version.
+     */
+    @Test
+    public void decode_HappyPath() throws Exception
+    {
+        buf = new SecondHandshakeMessage(MESSAGING_VERSION).encode(PooledByteBufAllocator.DEFAULT);
+        channel.writeInbound(buf);
+        Assert.assertEquals(1, channel.outboundMessages().size());
+        Assert.assertTrue(channel.isOpen());
+        Assert.assertTrue(channel.releaseOutbound()); // throw away any responses from decode()
+
+        Assert.assertEquals(MESSAGING_VERSION, callbackHandler.result.negotiatedMessagingVersion);
+        Assert.assertEquals(HandshakeResult.Outcome.SUCCESS, callbackHandler.result.outcome);
+    }
+
+    /**
+     * The handshake callback throws on its first invocation. The test expects the channel to be closed.
+     * CallbackHandler records only later invocations, so a subsequent callback must report NEGOTIATION_FAILURE
+     * with UNKNOWN_PROTOCOL_VERSION.
+     */
+    @Test
+    public void decode_HappyPathThrowsException() throws Exception
+    {
+        callbackHandler.failOnCallback = true;
+        buf = new SecondHandshakeMessage(MESSAGING_VERSION).encode(PooledByteBufAllocator.DEFAULT);
+        channel.writeInbound(buf);
+        Assert.assertFalse(channel.isOpen());
+        Assert.assertEquals(1, channel.outboundMessages().size());
+        Assert.assertTrue(channel.releaseOutbound()); // throw away any responses from decode()
+
+        Assert.assertEquals(UNKNOWN_PROTOCOL_VERSION, callbackHandler.result.negotiatedMessagingVersion);
+        Assert.assertEquals(HandshakeResult.Outcome.NEGOTIATION_FAILURE, callbackHandler.result.outcome);
+    }
+
+    /**
+     * The peer replies with a version one lower than ours. The callback must report DISCONNECT carrying the peer's
+     * version, the channel must be closed, and nothing may be passed inbound or written outbound.
+     */
+    @Test
+    public void decode_ReceivedLowerMsgVersion() throws Exception
+    {
+        int msgVersion = MESSAGING_VERSION - 1;
+        buf = new SecondHandshakeMessage(msgVersion).encode(PooledByteBufAllocator.DEFAULT);
+        channel.writeInbound(buf);
+        Assert.assertTrue(channel.inboundMessages().isEmpty());
+
+        Assert.assertEquals(msgVersion, callbackHandler.result.negotiatedMessagingVersion);
+        Assert.assertEquals(HandshakeResult.Outcome.DISCONNECT, callbackHandler.result.outcome);
+        Assert.assertFalse(channel.isOpen());
+        Assert.assertTrue(channel.outboundMessages().isEmpty());
+    }
+
+    /**
+     * The handler is rebuilt to request MESSAGING_VERSION - 1 (here msgVersion is OUR version), while the peer
+     * replies with MESSAGING_VERSION, i.e. a higher version. The callback must report DISCONNECT carrying the peer's
+     * version.
+     */
+    @Test
+    public void decode_ReceivedHigherMsgVersion() throws Exception
+    {
+        int msgVersion = MESSAGING_VERSION - 1;
+        channel.pipeline().remove(HANDLER_NAME);
+        params = OutboundConnectionParams.builder()
+                                         .connectionId(connectionId)
+                                         .callback(handshakeResult -> callbackHandler.receive(handshakeResult))
+                                         .mode(NettyFactory.Mode.MESSAGING)
+                                         .protocolVersion(msgVersion)
+                                         .coalescingStrategy(Optional.empty())
+                                         .build();
+        handler = new OutboundHandshakeHandler(params);
+        channel.pipeline().addFirst(HANDLER_NAME, handler);
+        buf = new SecondHandshakeMessage(MESSAGING_VERSION).encode(PooledByteBufAllocator.DEFAULT);
+        channel.writeInbound(buf);
+
+        Assert.assertEquals(MESSAGING_VERSION, callbackHandler.result.negotiatedMessagingVersion);
+        Assert.assertEquals(HandshakeResult.Outcome.DISCONNECT, callbackHandler.result.outcome);
+    }
+
+    /**
+     * With compression enabled, setupPipeline must add an Lz4FrameEncoder but no Lz4FrameDecoder, plus a
+     * MessageOutHandler.
+     */
+    @Test
+    public void setupPipeline_WithCompression()
+    {
+        EmbeddedChannel chan = new EmbeddedChannel(new ChannelOutboundHandlerAdapter());
+        ChannelPipeline pipeline = chan.pipeline();
+        params = OutboundConnectionParams.builder(params).compress(true).protocolVersion(MessagingService.current_version).build();
+        handler = new OutboundHandshakeHandler(params);
+        pipeline.addFirst(handler);
+        handler.setupPipeline(chan, MESSAGING_VERSION);
+        Assert.assertNotNull(pipeline.get(Lz4FrameEncoder.class));
+        Assert.assertNull(pipeline.get(Lz4FrameDecoder.class));
+        Assert.assertNotNull(pipeline.get(MessageOutHandler.class));
+    }
+
+    /** With compression disabled, setupPipeline must add no LZ4 handlers but still add a MessageOutHandler. */
+    @Test
+    public void setupPipeline_NoCompression()
+    {
+        EmbeddedChannel chan = new EmbeddedChannel(new ChannelOutboundHandlerAdapter());
+        ChannelPipeline pipeline = chan.pipeline();
+        params = OutboundConnectionParams.builder(params).compress(false).protocolVersion(MessagingService.current_version).build();
+        handler = new OutboundHandshakeHandler(params);
+        pipeline.addFirst(handler);
+        handler.setupPipeline(chan, MESSAGING_VERSION);
+        Assert.assertNull(pipeline.get(Lz4FrameEncoder.class));
+        Assert.assertNull(pipeline.get(Lz4FrameDecoder.class));
+        Assert.assertNotNull(pipeline.get(MessageOutHandler.class));
+    }
+
+    /**
+     * Handshake callback stub that stores the most recent result it accepted. When failOnCallback is set, the next
+     * invocation throws instead of recording, and the flag is cleared.
+     */
+    private static class CallbackHandler
+    {
+        boolean failOnCallback;
+        HandshakeResult result;
+
+        Void receive(HandshakeResult handshakeResult)
+        {
+            if (failOnCallback)
+            {
+                // only fail the first callback
+                failOnCallback = false;
+                throw new RuntimeException("this exception is expected in the test - DON'T PANIC");
+            }
+            result = handshakeResult;
+            return null;
+        }
+    }
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@cassandra.apache.org
For additional commands, e-mail: commits-h...@cassandra.apache.org