You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@cassandra.apache.org by ja...@apache.org on 2017/08/22 20:55:13 UTC
[06/11] cassandra git commit: switch internode messaging to netty
http://git-wip-us.apache.org/repos/asf/cassandra/blob/356dc3c2/test/unit/org/apache/cassandra/net/async/ChannelWriterTest.java
----------------------------------------------------------------------
diff --git a/test/unit/org/apache/cassandra/net/async/ChannelWriterTest.java b/test/unit/org/apache/cassandra/net/async/ChannelWriterTest.java
new file mode 100644
index 0000000..128fe4b
--- /dev/null
+++ b/test/unit/org/apache/cassandra/net/async/ChannelWriterTest.java
@@ -0,0 +1,312 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.cassandra.net.async;
+
+import java.io.IOException;
+import java.net.InetSocketAddress;
+import java.util.Optional;
+import java.util.concurrent.BlockingQueue;
+import java.util.concurrent.LinkedBlockingQueue;
+
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+import io.netty.buffer.ByteBuf;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.channel.ChannelOption;
+import io.netty.channel.ChannelOutboundHandlerAdapter;
+import io.netty.channel.ChannelPromise;
+import io.netty.channel.WriteBufferWaterMark;
+import io.netty.channel.embedded.EmbeddedChannel;
+import org.apache.cassandra.config.DatabaseDescriptor;
+import org.apache.cassandra.net.MessageOut;
+import org.apache.cassandra.net.MessagingService;
+import org.apache.cassandra.net.async.ChannelWriter.CoalescingChannelWriter;
+import org.apache.cassandra.utils.CoalescingStrategies;
+import org.apache.cassandra.utils.CoalescingStrategies.CoalescingStrategy;
+
+import static org.apache.cassandra.net.MessagingService.Verb.ECHO;
+
+/**
+ * Tests for {@link ChannelWriter}, covering both the simple (non-coalescing) writer and {@link CoalescingChannelWriter},
+ * using an {@link EmbeddedChannel} in place of a real socket.
+ * <p>
+ * With the write_Coalescing_* methods, if there's data in the channel.unsafe().outboundBuffer()
+ * it means that there's something in the channel that hasn't yet been flushed to the transport (socket).
+ * Once a flush occurs, there will be an entry in EmbeddedChannel's outboundQueue. Those two facts are leveraged in these tests.
+ */
+public class ChannelWriterTest
+{
+    /** Coalescing window handed to the FIXED coalescing strategy built in {@link #setup()}. */
+    private static final int COALESCE_WINDOW_MS = 10;
+
+    private EmbeddedChannel channel;
+    private ChannelWriter channelWriter;
+    private NonSendingOutboundMessagingConnection omc;
+    private Optional<CoalescingStrategy> coalescingStrategy;
+
+    @BeforeClass
+    public static void before()
+    {
+        DatabaseDescriptor.daemonInitialization();
+    }
+
+    /**
+     * Creates a fresh channel whose pipeline starts with a {@link MessageOutHandler} backed by a non-coalescing
+     * {@link ChannelWriter}, plus a FIXED coalescing strategy for the tests that need one.
+     */
+    @Before
+    public void setup()
+    {
+        OutboundConnectionIdentifier id = OutboundConnectionIdentifier.small(new InetSocketAddress("127.0.0.1", 0),
+                                                                             new InetSocketAddress("127.0.0.2", 0));
+        channel = new EmbeddedChannel();
+        omc = new NonSendingOutboundMessagingConnection(id, null, Optional.empty());
+        channelWriter = ChannelWriter.create(channel, omc::handleMessageResult, Optional.empty());
+        channel.pipeline().addFirst(new MessageOutHandler(id, MessagingService.current_version, channelWriter, () -> null));
+        coalescingStrategy = CoalescingStrategies.newCoalescingStrategy(CoalescingStrategies.Strategy.FIXED.name(), COALESCE_WINDOW_MS, null, "test");
+    }
+
+    @Test
+    public void create_nonCoalescing()
+    {
+        Assert.assertSame(ChannelWriter.SimpleChannelWriter.class, ChannelWriter.create(channel, omc::handleMessageResult, Optional.empty()).getClass());
+    }
+
+    @Test
+    public void create_Coalescing()
+    {
+        Assert.assertSame(CoalescingChannelWriter.class, ChannelWriter.create(channel, omc::handleMessageResult, coalescingStrategy).getClass());
+    }
+
+    @Test
+    public void write_IsWritable()
+    {
+        Assert.assertTrue(channel.isWritable());
+        Assert.assertTrue(channelWriter.write(new QueuedMessage(new MessageOut<>(ECHO), 42), true));
+        Assert.assertTrue(channel.isWritable());
+        Assert.assertTrue(channel.releaseOutbound());
+    }
+
+    /**
+     * When the second argument to write() is true and the channel is over its high water mark, the write is rejected
+     * and nothing reaches the outbound queue.
+     */
+    @Test
+    public void write_NotWritable()
+    {
+        channel.config().setOption(ChannelOption.WRITE_BUFFER_WATER_MARK, new WriteBufferWaterMark(1, 2));
+
+        // send one message through, which will trigger the writability check (and turn it off)
+        Assert.assertTrue(channel.isWritable());
+        ByteBuf buf = channel.alloc().buffer(8, 8);
+        channel.unsafe().outboundBuffer().addMessage(buf, buf.capacity(), channel.newPromise());
+        Assert.assertFalse(channel.isWritable());
+        Assert.assertFalse(channelWriter.write(new QueuedMessage(new MessageOut<>(ECHO), 42), true));
+        Assert.assertFalse(channel.isWritable());
+        Assert.assertFalse(channel.releaseOutbound());
+        // nothing was flushed to the outbound queue, so the buffer added above must be released by hand
+        buf.release();
+    }
+
+    /**
+     * When the second argument to write() is false, the write goes through even though the channel is not writable,
+     * and the resulting flush makes the channel writable again.
+     */
+    @Test
+    public void write_NotWritableButWriteAnyway()
+    {
+        channel.config().setOption(ChannelOption.WRITE_BUFFER_WATER_MARK, new WriteBufferWaterMark(1, 2));
+
+        // send one message through, which will trigger the writability check (and turn it off)
+        Assert.assertTrue(channel.isWritable());
+        ByteBuf buf = channel.alloc().buffer(8, 8);
+        channel.unsafe().outboundBuffer().addMessage(buf, buf.capacity(), channel.newPromise());
+        Assert.assertFalse(channel.isWritable());
+        Assert.assertTrue(channelWriter.write(new QueuedMessage(new MessageOut<>(ECHO), 42), false));
+        Assert.assertTrue(channel.isWritable());
+        // releases both the hand-added buffer and the written message, which were flushed to the outbound queue
+        Assert.assertTrue(channel.releaseOutbound());
+    }
+
+    /**
+     * If a flush is already scheduled, a new write is buffered but not flushed, and the flag stays set.
+     */
+    @Test
+    public void write_Coalescing_LostRaceForFlushTask()
+    {
+        CoalescingChannelWriter coalescingWriter = resetEnvForCoalescing(DatabaseDescriptor.getOtcCoalescingEnoughCoalescedMessages());
+        // pretend another writer already won the race to schedule the flush
+        coalescingWriter.scheduledFlush.set(true);
+        Assert.assertEquals(0, channel.unsafe().outboundBuffer().totalPendingWriteBytes());
+        Assert.assertTrue(coalescingWriter.write(new QueuedMessage(new MessageOut<>(ECHO), 42), true));
+        Assert.assertTrue(channel.unsafe().outboundBuffer().totalPendingWriteBytes() > 0);
+        Assert.assertFalse(channel.releaseOutbound());
+        Assert.assertTrue(coalescingWriter.scheduledFlush.get());
+    }
+
+    /**
+     * With a minimum of one message for coalescing, a single write is flushed immediately and no flush is scheduled.
+     */
+    @Test
+    public void write_Coalescing_HitMinMessageCountForImmediateCoalesce()
+    {
+        CoalescingChannelWriter coalescingWriter = resetEnvForCoalescing(1);
+
+        Assert.assertEquals(0, channel.unsafe().outboundBuffer().totalPendingWriteBytes());
+        Assert.assertFalse(coalescingWriter.scheduledFlush.get());
+        Assert.assertTrue(coalescingWriter.write(new QueuedMessage(new MessageOut<>(ECHO), 42), true));
+
+        Assert.assertEquals(0, channel.unsafe().outboundBuffer().totalPendingWriteBytes());
+        Assert.assertTrue(channel.releaseOutbound());
+        Assert.assertFalse(coalescingWriter.scheduledFlush.get());
+    }
+
+    /**
+     * A write below the coalescing threshold schedules a flush task. Running the channel's pending tasks performs the
+     * flush and clears the flag.
+     */
+    @Test
+    public void write_Coalescing_ScheduleFlushTask()
+    {
+        CoalescingChannelWriter coalescingWriter = resetEnvForCoalescing(DatabaseDescriptor.getOtcCoalescingEnoughCoalescedMessages());
+
+        Assert.assertEquals(0, channel.unsafe().outboundBuffer().totalPendingWriteBytes());
+        Assert.assertFalse(coalescingWriter.scheduledFlush.get());
+        Assert.assertTrue(coalescingWriter.write(new QueuedMessage(new MessageOut<>(ECHO), 42), true));
+
+        Assert.assertTrue(coalescingWriter.scheduledFlush.get());
+        Assert.assertTrue(channel.unsafe().outboundBuffer().totalPendingWriteBytes() > 0);
+
+        // this unfortunately knows a little too much about how the sausage is made in CoalescingChannelWriter :-/
+        channel.runScheduledPendingTasks();
+        channel.runPendingTasks();
+        Assert.assertEquals(0, channel.unsafe().outboundBuffer().totalPendingWriteBytes());
+        Assert.assertFalse(coalescingWriter.scheduledFlush.get());
+        Assert.assertTrue(channel.releaseOutbound());
+    }
+
+    /**
+     * Replaces {@link #channel} with a fresh channel driven by a {@link CoalescingChannelWriter}. Every pipeline flush
+     * is routed to {@link CoalescingChannelWriter#onTriggeredFlush}, and the new writer is registered with {@link #omc}.
+     *
+     * @param minMessagesForCoalesce threshold passed through to the {@link CoalescingChannelWriter} constructor
+     * @return the newly created writer
+     */
+    private CoalescingChannelWriter resetEnvForCoalescing(int minMessagesForCoalesce)
+    {
+        channel = new EmbeddedChannel();
+        CoalescingChannelWriter cw = new CoalescingChannelWriter(channel, omc::handleMessageResult, coalescingStrategy.get(), minMessagesForCoalesce);
+        channel.pipeline().addFirst(new ChannelOutboundHandlerAdapter()
+        {
+            @Override
+            public void flush(ChannelHandlerContext ctx) throws Exception
+            {
+                cw.onTriggeredFlush(ctx);
+            }
+        });
+        omc.setChannelWriter(cw);
+        return cw;
+    }
+
+    @Test
+    public void writeBacklog_Empty()
+    {
+        BlockingQueue<QueuedMessage> queue = new LinkedBlockingQueue<>();
+        Assert.assertEquals(0, channelWriter.writeBacklog(queue, false));
+        Assert.assertFalse(channel.releaseOutbound());
+    }
+
+    @Test
+    public void writeBacklog_ChannelNotWritable()
+    {
+        Assert.assertTrue(channel.isWritable());
+        // force the channel to be non writable
+        channel.config().setOption(ChannelOption.WRITE_BUFFER_WATER_MARK, new WriteBufferWaterMark(1, 2));
+        ByteBuf buf = channel.alloc().buffer(8, 8);
+        channel.unsafe().outboundBuffer().addMessage(buf, buf.capacity(), channel.newPromise());
+        Assert.assertFalse(channel.isWritable());
+
+        Assert.assertEquals(0, channelWriter.writeBacklog(new LinkedBlockingQueue<>(), false));
+        Assert.assertFalse(channel.releaseOutbound());
+        Assert.assertFalse(channel.isWritable());
+        // never flushed, so release the hand-added buffer explicitly
+        buf.release();
+    }
+
+    @Test
+    public void writeBacklog_NotEmpty()
+    {
+        BlockingQueue<QueuedMessage> queue = new LinkedBlockingQueue<>();
+        int count = 12;
+        for (int i = 0; i < count; i++)
+            queue.offer(new QueuedMessage(new MessageOut<>(ECHO), i));
+        Assert.assertEquals(count, channelWriter.writeBacklog(queue, false));
+        Assert.assertTrue(channel.releaseOutbound());
+    }
+
+    @Test
+    public void close()
+    {
+        Assert.assertFalse(channelWriter.isClosed());
+        Assert.assertTrue(channel.isOpen());
+        channelWriter.close();
+        Assert.assertFalse(channel.isOpen());
+        Assert.assertTrue(channelWriter.isClosed());
+    }
+
+    @Test
+    public void softClose()
+    {
+        Assert.assertFalse(channelWriter.isClosed());
+        Assert.assertTrue(channel.isOpen());
+        channelWriter.softClose();
+        Assert.assertFalse(channel.isOpen());
+        Assert.assertTrue(channelWriter.isClosed());
+    }
+
+    /** A cancelled write counts as completed, not dropped, and leaves the channel active. */
+    @Test
+    public void handleMessagePromise_FutureIsCancelled()
+    {
+        ChannelPromise promise = channel.newPromise();
+        promise.cancel(false);
+        channelWriter.handleMessageFuture(promise, new QueuedMessage(new MessageOut<>(ECHO), 1), true);
+        Assert.assertTrue(channel.isActive());
+        Assert.assertEquals(1, omc.getCompletedMessages().longValue());
+        Assert.assertEquals(0, omc.getDroppedMessages().longValue());
+    }
+
+    /** An {@link ExpiredException} failure counts the message as dropped and does not resend it. */
+    @Test
+    public void handleMessagePromise_ExpiredException_DoNotRetryMsg()
+    {
+        ChannelPromise promise = channel.newPromise();
+        promise.setFailure(new ExpiredException());
+
+        channelWriter.handleMessageFuture(promise, new QueuedMessage(new MessageOut<>(ECHO), 1), true);
+        Assert.assertTrue(channel.isActive());
+        Assert.assertEquals(1, omc.getCompletedMessages().longValue());
+        Assert.assertEquals(1, omc.getDroppedMessages().longValue());
+        Assert.assertFalse(omc.sendMessageInvoked);
+    }
+
+    /** A non-IO failure neither drops nor resends the message, and leaves the channel active. */
+    @Test
+    public void handleMessagePromise_NonIOException()
+    {
+        ChannelPromise promise = channel.newPromise();
+        promise.setFailure(new NullPointerException("this is a test"));
+        channelWriter.handleMessageFuture(promise, new QueuedMessage(new MessageOut<>(ECHO), 1), true);
+        Assert.assertTrue(channel.isActive());
+        Assert.assertEquals(1, omc.getCompletedMessages().longValue());
+        Assert.assertEquals(0, omc.getDroppedMessages().longValue());
+        Assert.assertFalse(omc.sendMessageInvoked);
+    }
+
+    /** An {@link IOException} failure deactivates the channel and resends the (retryable) message. */
+    @Test
+    public void handleMessagePromise_IOException_ChannelNotClosed_RetryMsg()
+    {
+        ChannelPromise promise = channel.newPromise();
+        promise.setFailure(new IOException("this is a test"));
+        Assert.assertTrue(channel.isActive());
+        channelWriter.handleMessageFuture(promise, new QueuedMessage(new MessageOut<>(ECHO), 1, 0, true, true), true);
+
+        Assert.assertFalse(channel.isActive());
+        Assert.assertEquals(1, omc.getCompletedMessages().longValue());
+        Assert.assertEquals(0, omc.getDroppedMessages().longValue());
+        Assert.assertTrue(omc.sendMessageInvoked);
+    }
+
+    /** A cancelled write of a retryable message is not resent and leaves the channel active. */
+    @Test
+    public void handleMessagePromise_Cancelled()
+    {
+        ChannelPromise promise = channel.newPromise();
+        promise.cancel(false);
+        Assert.assertTrue(channel.isActive());
+        channelWriter.handleMessageFuture(promise, new QueuedMessage(new MessageOut<>(ECHO), 1, 0, true, true), true);
+
+        Assert.assertTrue(channel.isActive());
+        Assert.assertEquals(1, omc.getCompletedMessages().longValue());
+        Assert.assertEquals(0, omc.getDroppedMessages().longValue());
+        Assert.assertFalse(omc.sendMessageInvoked);
+    }
+}
http://git-wip-us.apache.org/repos/asf/cassandra/blob/356dc3c2/test/unit/org/apache/cassandra/net/async/HandshakeHandlersTest.java
----------------------------------------------------------------------
diff --git a/test/unit/org/apache/cassandra/net/async/HandshakeHandlersTest.java b/test/unit/org/apache/cassandra/net/async/HandshakeHandlersTest.java
new file mode 100644
index 0000000..fa6e2b5
--- /dev/null
+++ b/test/unit/org/apache/cassandra/net/async/HandshakeHandlersTest.java
@@ -0,0 +1,204 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.cassandra.net.async;
+
+import java.io.IOException;
+import java.net.InetSocketAddress;
+import java.nio.ByteBuffer;
+import java.util.Optional;
+
+import org.junit.Assert;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+import io.netty.channel.embedded.EmbeddedChannel;
+import org.apache.cassandra.SchemaLoader;
+import org.apache.cassandra.auth.AllowAllInternodeAuthenticator;
+import org.apache.cassandra.db.ColumnFamilyStore;
+import org.apache.cassandra.db.Keyspace;
+import org.apache.cassandra.db.Mutation;
+import org.apache.cassandra.db.RowUpdateBuilder;
+import org.apache.cassandra.db.compaction.CompactionManager;
+import org.apache.cassandra.db.marshal.AsciiType;
+import org.apache.cassandra.db.marshal.BytesType;
+import org.apache.cassandra.exceptions.ConfigurationException;
+import org.apache.cassandra.net.MessageOut;
+import org.apache.cassandra.net.MessagingService;
+import org.apache.cassandra.schema.KeyspaceParams;
+
+import static org.apache.cassandra.net.async.InboundHandshakeHandler.State.MESSAGING_HANDSHAKE_COMPLETE;
+import static org.apache.cassandra.net.async.OutboundMessagingConnection.State.READY;
+
+/**
+ * Tests for the inbound and outbound internode handshake handlers. Covers the full three-message handshake between
+ * two {@link EmbeddedChannel}s, and the messaging pipelines (with and without compression) once a handshake is done.
+ */
+public class HandshakeHandlersTest
+{
+    private static final String KEYSPACE1 = "NettyPipilineTest";
+    private static final String STANDARD1 = "Standard1";
+
+    private static final InetSocketAddress LOCAL_ADDR = new InetSocketAddress("127.0.0.1", 9999);
+    private static final InetSocketAddress REMOTE_ADDR = new InetSocketAddress("127.0.0.2", 9999);
+    private static final int MESSAGING_VERSION = MessagingService.current_version;
+    private static final OutboundConnectionIdentifier connectionId = OutboundConnectionIdentifier.small(LOCAL_ADDR, REMOTE_ADDR);
+
+    @BeforeClass
+    public static void beforeClass() throws ConfigurationException
+    {
+        SchemaLoader.prepareServer();
+        SchemaLoader.createKeyspace(KEYSPACE1,
+                                    KeyspaceParams.simple(1),
+                                    SchemaLoader.standardCFMD(KEYSPACE1, STANDARD1, 0, AsciiType.instance, BytesType.instance));
+        CompactionManager.instance.disableAutoCompaction();
+    }
+
+    /**
+     * Shuttles the three handshake messages between an outbound and an inbound channel by hand. Then checks that the
+     * outbound connection is READY and the inbound handler reports MESSAGING_HANDSHAKE_COMPLETE.
+     */
+    @Test
+    public void handshake_HappyPath()
+    {
+        // because both handshake handlers are ChannelInboundHandlers, we can't use the same EmbeddedChannel to handle them
+        InboundHandshakeHandler inboundHandshakeHandler = new InboundHandshakeHandler(new TestAuthenticator(true));
+        EmbeddedChannel inboundChannel = new EmbeddedChannel(inboundHandshakeHandler);
+
+        OutboundMessagingConnection omc = new OutboundMessagingConnection(connectionId, null, Optional.empty(), new AllowAllInternodeAuthenticator());
+        OutboundConnectionParams params = OutboundConnectionParams.builder()
+                                                                  .connectionId(connectionId)
+                                                                  .callback(omc::finishHandshake)
+                                                                  .mode(NettyFactory.Mode.MESSAGING)
+                                                                  .protocolVersion(MessagingService.current_version)
+                                                                  .coalescingStrategy(Optional.empty())
+                                                                  .build();
+        OutboundHandshakeHandler outboundHandshakeHandler = new OutboundHandshakeHandler(params);
+        EmbeddedChannel outboundChannel = new EmbeddedChannel(outboundHandshakeHandler);
+        Assert.assertEquals(1, outboundChannel.outboundMessages().size());
+
+        // move internode protocol Msg1 to the server's channel
+        Object o;
+        while ((o = outboundChannel.readOutbound()) != null)
+            inboundChannel.writeInbound(o);
+        Assert.assertEquals(1, inboundChannel.outboundMessages().size());
+
+        // move internode protocol Msg2 to the client's channel
+        while ((o = inboundChannel.readOutbound()) != null)
+            outboundChannel.writeInbound(o);
+        Assert.assertEquals(1, outboundChannel.outboundMessages().size());
+
+        // move internode protocol Msg3 to the server's channel
+        while ((o = outboundChannel.readOutbound()) != null)
+            inboundChannel.writeInbound(o);
+
+        Assert.assertEquals(READY, omc.getState());
+        Assert.assertEquals(MESSAGING_HANDSHAKE_COMPLETE, inboundHandshakeHandler.getState());
+    }
+
+    @Test
+    public void lotsOfMutations_NoCompression() throws IOException
+    {
+        lotsOfMutations(false);
+    }
+
+    @Test
+    public void lotsOfMutations_WithCompression() throws IOException
+    {
+        lotsOfMutations(true);
+    }
+
+    /**
+     * Writes 1024 messages (even indices are mutations carrying a 1 KiB value, odd indices are ECHOs) through an
+     * outbound messaging pipeline. Then feeds every encoded chunk into the inbound pipeline and checks that nothing is
+     * left over on either side.
+     *
+     * @param compress whether both pipelines are built with compression enabled
+     */
+    private void lotsOfMutations(boolean compress)
+    {
+        TestChannels channels = buildChannels(compress);
+        EmbeddedChannel outboundChannel = channels.outboundChannel;
+        EmbeddedChannel inboundChannel = channels.inboundChannel;
+
+        // now the actual test!
+        ByteBuffer buf = ByteBuffer.allocate(1 << 10);
+        byte[] bytes = "ThisIsA16CharStr".getBytes();
+        while (buf.remaining() > 0)
+            buf.put(bytes);
+        // switch the buffer from filling to reading; without this remaining() is 0 and every mutation carries an
+        // empty value instead of the intended 1 KiB payload
+        buf.flip();
+
+        // write a bunch of messages to the channel
+        ColumnFamilyStore cfs1 = Keyspace.open(KEYSPACE1).getColumnFamilyStore(STANDARD1);
+        int count = 1024;
+        for (int i = 0; i < count; i++)
+        {
+            if (i % 2 == 0)
+            {
+                Mutation mutation = new RowUpdateBuilder(cfs1.metadata.get(), 0, "k")
+                                    .clustering("bytes")
+                                    .add("val", buf)
+                                    .build();
+
+                QueuedMessage msg = new QueuedMessage(mutation.createMessage(), i);
+                outboundChannel.writeAndFlush(msg);
+            }
+            else
+            {
+                outboundChannel.writeAndFlush(new QueuedMessage(new MessageOut<>(MessagingService.Verb.ECHO), i));
+            }
+        }
+
+        // move the messages to the other channel
+        Object o;
+        while ((o = outboundChannel.readOutbound()) != null)
+            inboundChannel.writeInbound(o);
+
+        Assert.assertTrue(outboundChannel.outboundMessages().isEmpty());
+        Assert.assertFalse(inboundChannel.finishAndReleaseAll());
+    }
+
+    /**
+     * Builds an outbound and an inbound channel whose messaging pipelines are set up directly at
+     * {@link #MESSAGING_VERSION}, skipping the handshake exchange.
+     *
+     * @param compress whether both pipelines use compression
+     */
+    private TestChannels buildChannels(boolean compress)
+    {
+        OutboundConnectionParams params = OutboundConnectionParams.builder()
+                                                                  .connectionId(connectionId)
+                                                                  .callback(this::nop)
+                                                                  .mode(NettyFactory.Mode.MESSAGING)
+                                                                  .compress(compress)
+                                                                  .coalescingStrategy(Optional.empty())
+                                                                  .protocolVersion(MessagingService.current_version)
+                                                                  .build();
+        OutboundHandshakeHandler outboundHandshakeHandler = new OutboundHandshakeHandler(params);
+        EmbeddedChannel outboundChannel = new EmbeddedChannel(outboundHandshakeHandler);
+        // NOTE(review): omc is not referenced after this point; confirm whether it is still needed
+        OutboundMessagingConnection omc = new OutboundMessagingConnection(connectionId, null, Optional.empty(), new AllowAllInternodeAuthenticator());
+        omc.setTargetVersion(MESSAGING_VERSION);
+        outboundHandshakeHandler.setupPipeline(outboundChannel, MESSAGING_VERSION);
+
+        // remove the outbound handshake message from the outbound messages
+        outboundChannel.outboundMessages().clear();
+
+        InboundHandshakeHandler handler = new InboundHandshakeHandler(new TestAuthenticator(true));
+        EmbeddedChannel inboundChannel = new EmbeddedChannel(handler);
+        handler.setupMessagingPipeline(inboundChannel.pipeline(), REMOTE_ADDR.getAddress(), compress, MESSAGING_VERSION);
+
+        return new TestChannels(outboundChannel, inboundChannel);
+    }
+
+    /** Simple holder for the pair of channels built by {@link #buildChannels(boolean)}. */
+    private static class TestChannels
+    {
+        final EmbeddedChannel outboundChannel;
+        final EmbeddedChannel inboundChannel;
+
+        TestChannels(EmbeddedChannel outboundChannel, EmbeddedChannel inboundChannel)
+        {
+            this.outboundChannel = outboundChannel;
+            this.inboundChannel = inboundChannel;
+        }
+    }
+
+    /** Handshake callback that ignores the result. */
+    private Void nop(OutboundHandshakeHandler.HandshakeResult handshakeResult)
+    {
+        // do nothing, really
+        return null;
+    }
+}
http://git-wip-us.apache.org/repos/asf/cassandra/blob/356dc3c2/test/unit/org/apache/cassandra/net/async/HandshakeProtocolTest.java
----------------------------------------------------------------------
diff --git a/test/unit/org/apache/cassandra/net/async/HandshakeProtocolTest.java b/test/unit/org/apache/cassandra/net/async/HandshakeProtocolTest.java
new file mode 100644
index 0000000..a3d646d
--- /dev/null
+++ b/test/unit/org/apache/cassandra/net/async/HandshakeProtocolTest.java
@@ -0,0 +1,95 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.cassandra.net.async;
+
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.PooledByteBufAllocator;
+import org.apache.cassandra.config.DatabaseDescriptor;
+import org.apache.cassandra.net.MessagingService;
+import org.apache.cassandra.net.async.HandshakeProtocol.FirstHandshakeMessage;
+import org.apache.cassandra.net.async.HandshakeProtocol.SecondHandshakeMessage;
+import org.apache.cassandra.net.async.HandshakeProtocol.ThirdHandshakeMessage;
+import org.apache.cassandra.utils.FBUtilities;
+
+import static org.junit.Assert.assertEquals;
+
+/**
+ * Round-trip (encode, then decode) tests for the three internode handshake messages in {@link HandshakeProtocol}.
+ */
+public class HandshakeProtocolTest
+{
+    /** Buffer encoded by the current test; released in {@link #tearDown()} if still referenced. */
+    private ByteBuf buf;
+
+    @BeforeClass
+    public static void before()
+    {
+        // Kind of stupid, but the tests trigger the initialization of the MessagingService class, which requires
+        // DatabaseDescriptor to be configured ...
+        DatabaseDescriptor.daemonInitialization();
+    }
+
+    @After
+    public void tearDown()
+    {
+        if (buf != null && buf.refCnt() > 0)
+            buf.release();
+    }
+
+    @Test
+    public void firstMessageTest() throws Exception
+    {
+        firstMessageTest(NettyFactory.Mode.MESSAGING, false);
+        firstMessageTest(NettyFactory.Mode.MESSAGING, true);
+        firstMessageTest(NettyFactory.Mode.STREAMING, false);
+        firstMessageTest(NettyFactory.Mode.STREAMING, true);
+    }
+
+    /**
+     * Encodes a {@link FirstHandshakeMessage} for the given mode/compression, decodes it again, and checks that the
+     * two are equal.
+     */
+    private void firstMessageTest(NettyFactory.Mode mode, boolean compression) throws Exception
+    {
+        // this helper runs several times per test; release the previous pooled buffer before overwriting the field,
+        // otherwise only the last one would be released by tearDown() and the rest would leak
+        if (buf != null && buf.refCnt() > 0)
+            buf.release();
+
+        FirstHandshakeMessage before = new FirstHandshakeMessage(MessagingService.current_version, mode, compression);
+        buf = before.encode(PooledByteBufAllocator.DEFAULT);
+        FirstHandshakeMessage after = FirstHandshakeMessage.maybeDecode(buf);
+        assertEquals(before, after);
+        assertEquals(before.hashCode(), after.hashCode());
+        Assert.assertFalse(before.equals(null));
+    }
+
+    @Test
+    public void secondMessageTest() throws Exception
+    {
+        SecondHandshakeMessage before = new SecondHandshakeMessage(MessagingService.current_version);
+        buf = before.encode(PooledByteBufAllocator.DEFAULT);
+        SecondHandshakeMessage after = SecondHandshakeMessage.maybeDecode(buf);
+        assertEquals(before, after);
+        assertEquals(before.hashCode(), after.hashCode());
+        Assert.assertFalse(before.equals(null));
+    }
+
+    @Test
+    public void thirdMessageTest() throws Exception
+    {
+        ThirdHandshakeMessage before = new ThirdHandshakeMessage(MessagingService.current_version, FBUtilities.getBroadcastAddress());
+        buf = before.encode(PooledByteBufAllocator.DEFAULT);
+        ThirdHandshakeMessage after = ThirdHandshakeMessage.maybeDecode(buf);
+        assertEquals(before, after);
+        assertEquals(before.hashCode(), after.hashCode());
+        Assert.assertFalse(before.equals(null));
+    }
+}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/cassandra/blob/356dc3c2/test/unit/org/apache/cassandra/net/async/InboundHandshakeHandlerTest.java
----------------------------------------------------------------------
diff --git a/test/unit/org/apache/cassandra/net/async/InboundHandshakeHandlerTest.java b/test/unit/org/apache/cassandra/net/async/InboundHandshakeHandlerTest.java
new file mode 100644
index 0000000..44dc469
--- /dev/null
+++ b/test/unit/org/apache/cassandra/net/async/InboundHandshakeHandlerTest.java
@@ -0,0 +1,289 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.cassandra.net.async;
+
+import java.io.IOException;
+import java.net.InetSocketAddress;
+import java.net.SocketAddress;
+import java.util.ArrayList;
+
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.ByteBufOutputStream;
+import io.netty.buffer.PooledByteBufAllocator;
+import io.netty.buffer.Unpooled;
+import io.netty.channel.ChannelPipeline;
+import io.netty.channel.ChannelPromise;
+import io.netty.channel.embedded.EmbeddedChannel;
+import io.netty.handler.codec.compression.Lz4FrameDecoder;
+import io.netty.handler.codec.compression.Lz4FrameEncoder;
+import org.apache.cassandra.config.DatabaseDescriptor;
+import org.apache.cassandra.net.CompactEndpointSerializationHelper;
+import org.apache.cassandra.net.MessagingService;
+import org.apache.cassandra.net.async.HandshakeProtocol.FirstHandshakeMessage;
+import org.apache.cassandra.net.async.HandshakeProtocol.ThirdHandshakeMessage;
+import org.apache.cassandra.net.async.InboundHandshakeHandler.State;
+
+import static org.apache.cassandra.net.async.NettyFactory.Mode.MESSAGING;
+
+public class InboundHandshakeHandlerTest
+{
+ // loopback address used as the peer address for authentication and in encoded handshake messages
+ private static final InetSocketAddress addr = new InetSocketAddress("127.0.0.1", 0);
+ private static final int MESSAGING_VERSION = MessagingService.current_version;
+ // lowest messaging version the version checks in these tests are measured against
+ private static final int VERSION_30 = MessagingService.VERSION_30;
+
+ // handler under test and the EmbeddedChannel wrapping it; rebuilt for every test in setUp()
+ private InboundHandshakeHandler handler;
+ private EmbeddedChannel channel;
+ // buffer created by the current test, if any; released in tearDown()
+ private ByteBuf buf;
+
+    /**
+     * Initializes {@link DatabaseDescriptor} once, before any test in this class runs
+     * (presumably required by the handshake code under test).
+     */
+    @BeforeClass
+    public static void beforeClass()
+    {
+        DatabaseDescriptor.daemonInitialization();
+    }
+
+    /**
+     * Builds the default fixture: a handler whose authenticator is constructed with {@code false},
+     * wrapped in a fresh {@link EmbeddedChannel}.
+     */
+    @Before
+    public void setUp()
+    {
+        handler = new InboundHandshakeHandler(new TestAuthenticator(false));
+        channel = new EmbeddedChannel(handler);
+    }
+
+    /**
+     * Releases the test's buffer (if any) and finishes the channel, releasing anything left in it.
+     */
+    @After
+    public void tearDown()
+    {
+        // Releasing a buffer whose refCnt is already 0 throws. Guard the same way
+        // handleStart_HappyPath_Messaging does before it reuses the field.
+        if (buf != null && buf.refCnt() > 0)
+            buf.release();
+        channel.finishAndReleaseAll();
+    }
+
+    /** With an accepting authenticator, authentication succeeds and the channel stays open. */
+    @Test
+    public void handleAuthenticate_Good()
+    {
+        // swap the default fixture for one whose authenticator is constructed with true
+        handler = new InboundHandshakeHandler(new TestAuthenticator(true));
+        channel = new EmbeddedChannel(handler);
+
+        Assert.assertTrue(handler.handleAuthenticate(addr, channel.pipeline().firstContext()));
+        Assert.assertTrue(channel.isOpen());
+    }
+
+    /** With the default (rejecting) fixture, authentication fails and the channel is closed. */
+    @Test
+    public void handleAuthenticate_Bad()
+    {
+        Assert.assertFalse(handler.handleAuthenticate(addr, channel.pipeline().firstContext()));
+
+        Assert.assertFalse(channel.isOpen());
+        Assert.assertFalse(channel.isActive());
+    }
+
+    /** A socket address that is not an {@link InetSocketAddress} fails authentication and closes the channel. */
+    @Test
+    public void handleAuthenticate_BadSocketAddr()
+    {
+        SocketAddress bogus = new FakeSocketAddress();
+        Assert.assertFalse(handler.handleAuthenticate(bogus, channel.pipeline().firstContext()));
+
+        Assert.assertFalse(channel.isOpen());
+        Assert.assertFalse(channel.isActive());
+    }
+
+ private static class FakeSocketAddress extends SocketAddress
+ { }
+
+    /** Once the handler is in HANDSHAKE_FAIL, decoding further input closes the channel and keeps that state. */
+    @Test
+    public void decode_AlreadyFailed()
+    {
+        handler.setState(State.HANDSHAKE_FAIL);
+        buf = new FirstHandshakeMessage(MESSAGING_VERSION, MESSAGING, true).encode(PooledByteBufAllocator.DEFAULT);
+
+        ArrayList<Object> decoded = new ArrayList<>();
+        handler.decode(channel.pipeline().firstContext(), buf, decoded);
+
+        Assert.assertFalse(channel.isOpen());
+        Assert.assertFalse(channel.isActive());
+        Assert.assertSame(State.HANDSHAKE_FAIL, handler.getState());
+    }
+
+    /** With no input bytes, handleStart stays in START and leaves the channel open and active. */
+    @Test
+    public void handleStart_NotEnoughInputBytes() throws IOException
+    {
+        // pass the shared empty buffer directly rather than through a local that shadows the buf field
+        Assert.assertEquals(State.START, handler.handleStart(channel.pipeline().firstContext(), Unpooled.EMPTY_BUFFER));
+
+        Assert.assertTrue(channel.isOpen());
+        Assert.assertTrue(channel.isActive());
+    }
+
+    /** A first message whose leading int is not {@code MessagingService.PROTOCOL_MAGIC} makes handleStart throw. */
+    @Test (expected = IOException.class)
+    public void handleStart_BadMagic() throws IOException
+    {
+        FirstHandshakeMessage first = new FirstHandshakeMessage(MESSAGING_VERSION,
+                                                                MESSAGING,
+                                                                true);
+
+        // hand-encode the message with a corrupted magic number but valid flags
+        buf = Unpooled.buffer(32, 32);
+        buf.writeInt(MessagingService.PROTOCOL_MAGIC << 2);
+        buf.writeInt(first.encodeFlags());
+
+        // Use the fixture handler/channel; setUp() builds them with the same TestAuthenticator(false) this test used
+        // to create locally. That local channel shadowed the fields and was never finished.
+        handler.handleStart(channel.pipeline().firstContext(), buf);
+    }
+
+    /** A peer advertising a messaging version newer than ours fails the handshake and closes the channel. */
+    @Test
+    public void handleStart_VersionTooHigh() throws IOException
+    {
+        // (a no-op `channel.eventLoop();` call whose result was discarded has been removed)
+        buf = new FirstHandshakeMessage(MESSAGING_VERSION + 1, MESSAGING, true).encode(PooledByteBufAllocator.DEFAULT);
+        State state = handler.handleStart(channel.pipeline().firstContext(), buf);
+        Assert.assertEquals(State.HANDSHAKE_FAIL, state);
+        Assert.assertFalse(channel.isOpen());
+        Assert.assertFalse(channel.isActive());
+    }
+
+ @Test
+ public void handleStart_VersionLessThan3_0() throws IOException
+ {
+ buf = new FirstHandshakeMessage(VERSION_30 - 1, MESSAGING, true).encode(PooledByteBufAllocator.DEFAULT);
+ State state = handler.handleStart(channel.pipeline().firstContext(), buf);
+ Assert.assertEquals(State.HANDSHAKE_FAIL, state);
+
+ Assert.assertFalse(channel.isOpen());
+ Assert.assertFalse(channel.isActive());
+ }
+
+    /**
+     * Drives the inbound side through a full messaging handshake: a valid first message moves it to
+     * AWAIT_MESSAGING_START_RESPONSE, and a valid third message completes it. Along the way, output is written to
+     * the channel's outbound queue.
+     */
+    @Test
+    public void handleStart_HappyPath_Messaging() throws IOException
+    {
+        // step 1: the peer's first handshake message
+        buf = new FirstHandshakeMessage(MESSAGING_VERSION, MESSAGING, true).encode(PooledByteBufAllocator.DEFAULT);
+        Assert.assertEquals(State.AWAIT_MESSAGING_START_RESPONSE, handler.handleStart(channel.pipeline().firstContext(), buf));
+        // release the first buffer (unless already fully released) before the field is reused below
+        if (buf.refCnt() > 0)
+            buf.release();
+
+        // step 2: the peer's third handshake message
+        buf = new ThirdHandshakeMessage(MESSAGING_VERSION, addr.getAddress()).encode(PooledByteBufAllocator.DEFAULT);
+        Assert.assertEquals(State.MESSAGING_HANDSHAKE_COMPLETE, handler.handleMessagingStartResponse(channel.pipeline().firstContext(), buf));
+
+        Assert.assertTrue(channel.isOpen());
+        Assert.assertTrue(channel.isActive());
+        Assert.assertFalse(channel.outboundMessages().isEmpty());
+        channel.releaseOutbound();
+    }
+
+ /**
+ * Verifies that {@code handleMessagingStartResponse} given no bytes returns
+ * {@link State#AWAIT_MESSAGING_START_RESPONSE} and leaves the channel open and active.
+ */
+ @Test
+ public void handleMessagingStartResponse_NotEnoughInputBytes() throws IOException
+ {
+ State stateAfterEmptyInput = handler.handleMessagingStartResponse(channel.pipeline().firstContext(),
+ Unpooled.EMPTY_BUFFER);
+
+ Assert.assertEquals(State.AWAIT_MESSAGING_START_RESPONSE, stateAfterEmptyInput);
+ Assert.assertTrue(channel.isOpen());
+ Assert.assertTrue(channel.isActive());
+ }
+
+ @Test
+ public void handleMessagingStartResponse_BadMaxVersion() throws IOException
+ {
+ buf = Unpooled.buffer(32, 32);
+ buf.writeInt(MESSAGING_VERSION + 1);
+ CompactEndpointSerializationHelper.serialize(addr.getAddress(), new ByteBufOutputStream(buf));
+ State state = handler.handleMessagingStartResponse(channel.pipeline().firstContext(), buf);
+ Assert.assertEquals(State.HANDSHAKE_FAIL, state);
+ Assert.assertFalse(channel.isOpen());
+ Assert.assertFalse(channel.isActive());
+ }
+
+ /**
+ * Verifies that a messaging-start response (our own version followed by the serialized address)
+ * completes the messaging handshake and keeps the channel open and active.
+ */
+ @Test
+ public void handleMessagingStartResponse_HappyPath() throws IOException
+ {
+ buf = Unpooled.buffer(32, 32);
+ buf.writeInt(MESSAGING_VERSION);
+ ByteBufOutputStream addressOut = new ByteBufOutputStream(buf);
+ CompactEndpointSerializationHelper.serialize(addr.getAddress(), addressOut);
+
+ Assert.assertEquals(State.MESSAGING_HANDSHAKE_COMPLETE,
+ handler.handleMessagingStartResponse(channel.pipeline().firstContext(), buf));
+ Assert.assertTrue(channel.isOpen());
+ Assert.assertTrue(channel.isActive());
+ }
+
+ @Test
+ public void setupPipeline_NoCompression()
+ {
+ ChannelPipeline pipeline = channel.pipeline();
+ Assert.assertNotNull(pipeline.get(InboundHandshakeHandler.class));
+
+ handler.setupMessagingPipeline(pipeline, addr.getAddress(), false, MESSAGING_VERSION);
+ Assert.assertNotNull(pipeline.get(MessageInHandler.class));
+ Assert.assertNull(pipeline.get(Lz4FrameDecoder.class));
+ Assert.assertNull(pipeline.get(Lz4FrameEncoder.class));
+ Assert.assertNull(pipeline.get(InboundHandshakeHandler.class));
+ }
+
+ @Test
+ public void setupPipeline_WithCompression()
+ {
+ ChannelPipeline pipeline = channel.pipeline();
+ Assert.assertNotNull(pipeline.get(InboundHandshakeHandler.class));
+
+ handler.setupMessagingPipeline(pipeline, addr.getAddress(), true, MESSAGING_VERSION);
+ Assert.assertNotNull(pipeline.get(MessageInHandler.class));
+ Assert.assertNotNull(pipeline.get(Lz4FrameDecoder.class));
+ Assert.assertNull(pipeline.get(Lz4FrameEncoder.class));
+ Assert.assertNull(pipeline.get(InboundHandshakeHandler.class));
+ }
+
+ /**
+ * Failing an in-progress handshake moves the handler to {@link State#HANDSHAKE_FAIL},
+ * cancels the pending handshake timeout, and closes the channel.
+ */
+ @Test
+ public void failHandshake()
+ {
+ ChannelPromise timeoutFuture = channel.newPromise();
+ handler.setHandshakeTimeout(timeoutFuture);
+ Assert.assertFalse(timeoutFuture.isCancelled());
+ Assert.assertTrue(channel.isOpen());
+
+ handler.failHandshake(channel.pipeline().firstContext());
+
+ Assert.assertSame(State.HANDSHAKE_FAIL, handler.getState());
+ Assert.assertTrue(timeoutFuture.isCancelled());
+ Assert.assertFalse(channel.isOpen());
+ }
+
+ /**
+ * Once the messaging handshake has completed, {@code failHandshake} leaves the state at
+ * {@link State#MESSAGING_HANDSHAKE_COMPLETE} and the channel open.
+ */
+ @Test
+ public void failHandshake_AlreadyConnected()
+ {
+ ChannelPromise timeoutFuture = channel.newPromise();
+ handler.setHandshakeTimeout(timeoutFuture);
+ Assert.assertFalse(timeoutFuture.isCancelled());
+ Assert.assertTrue(channel.isOpen());
+
+ handler.setState(State.MESSAGING_HANDSHAKE_COMPLETE);
+ handler.failHandshake(channel.pipeline().firstContext());
+
+ Assert.assertSame(State.MESSAGING_HANDSHAKE_COMPLETE, handler.getState());
+ Assert.assertTrue(channel.isOpen());
+ }
+
+ /**
+ * When the handshake timeout future has already been cancelled, {@code failHandshake} leaves the state
+ * ({@link State#AWAIT_MESSAGING_START_RESPONSE}) and the channel untouched.
+ */
+ @Test
+ public void failHandshake_TaskIsCancelled()
+ {
+ ChannelPromise cancelledTimeout = channel.newPromise();
+ cancelledTimeout.cancel(false);
+ handler.setHandshakeTimeout(cancelledTimeout);
+ handler.setState(State.AWAIT_MESSAGING_START_RESPONSE);
+ Assert.assertTrue(channel.isOpen());
+
+ handler.failHandshake(channel.pipeline().firstContext());
+
+ Assert.assertSame(State.AWAIT_MESSAGING_START_RESPONSE, handler.getState());
+ Assert.assertTrue(channel.isOpen());
+ }
+}
http://git-wip-us.apache.org/repos/asf/cassandra/blob/356dc3c2/test/unit/org/apache/cassandra/net/async/MessageInHandlerTest.java
----------------------------------------------------------------------
diff --git a/test/unit/org/apache/cassandra/net/async/MessageInHandlerTest.java b/test/unit/org/apache/cassandra/net/async/MessageInHandlerTest.java
new file mode 100644
index 0000000..bb82d2c
--- /dev/null
+++ b/test/unit/org/apache/cassandra/net/async/MessageInHandlerTest.java
@@ -0,0 +1,242 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.cassandra.net.async;
+
+import java.io.EOFException;
+import java.io.IOException;
+import java.net.InetSocketAddress;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.function.BiConsumer;
+
+import com.google.common.base.Charsets;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.ByteBufOutputStream;
+import io.netty.buffer.PooledByteBufAllocator;
+import io.netty.buffer.Unpooled;
+import io.netty.channel.Channel;
+import io.netty.channel.embedded.EmbeddedChannel;
+import io.netty.channel.epoll.EpollEventLoopGroup;
+import io.netty.channel.nio.NioEventLoopGroup;
+import io.netty.util.concurrent.Future;
+import org.apache.cassandra.config.DatabaseDescriptor;
+import org.apache.cassandra.net.MessageIn;
+import org.apache.cassandra.net.MessageOut;
+import org.apache.cassandra.net.MessagingService;
+import org.apache.cassandra.net.async.MessageInHandler.MessageHeader;
+import org.apache.cassandra.utils.FBUtilities;
+import org.apache.cassandra.utils.NanoTimeToCurrentTimeMillis;
+
+public class MessageInHandlerTest
+{
+ private static final InetSocketAddress addr = new InetSocketAddress("127.0.0.1", 0);
+ private static final int MSG_VERSION = MessagingService.current_version;
+
+ private static final int MSG_ID = 42;
+
+ private ByteBuf buf;
+
+ @BeforeClass
+ public static void before()
+ {
+ DatabaseDescriptor.daemonInitialization();
+ }
+
+ @After
+ public void tearDown()
+ {
+ if (buf != null && buf.refCnt() > 0)
+ buf.release();
+ }
+
+ @Test
+ public void decode_BadMagic() throws Exception
+ {
+ int len = MessageInHandler.FIRST_SECTION_BYTE_COUNT;
+ buf = Unpooled.buffer(len, len);
+ buf.writeInt(-1);
+ buf.writerIndex(len);
+
+ MessageInHandler handler = new MessageInHandler(addr.getAddress(), MSG_VERSION, null);
+ EmbeddedChannel channel = new EmbeddedChannel(handler);
+ Assert.assertTrue(channel.isOpen());
+ channel.writeInbound(buf);
+ Assert.assertFalse(channel.isOpen());
+ }
+
+ @Test
+ public void decode_HappyPath_NoParameters() throws Exception
+ {
+ MessageInWrapper result = decode_HappyPath(Collections.emptyMap());
+ Assert.assertTrue(result.messageIn.parameters.isEmpty());
+ }
+
+ @Test
+ public void decode_HappyPath_WithParameters() throws Exception
+ {
+ Map<String, byte[]> parameters = new HashMap<>();
+ parameters.put("p1", "val1".getBytes(Charsets.UTF_8));
+ parameters.put("p2", "val2".getBytes(Charsets.UTF_8));
+ MessageInWrapper result = decode_HappyPath(parameters);
+ Assert.assertEquals(2, result.messageIn.parameters.size());
+ }
+
+ private MessageInWrapper decode_HappyPath(Map<String, byte[]> parameters) throws Exception
+ {
+ MessageOut msgOut = new MessageOut(MessagingService.Verb.ECHO);
+ for (Map.Entry<String, byte[]> param : parameters.entrySet())
+ msgOut = msgOut.withParameter(param.getKey(), param.getValue());
+ serialize(msgOut);
+
+ MessageInWrapper wrapper = new MessageInWrapper();
+ MessageInHandler handler = new MessageInHandler(addr.getAddress(), MSG_VERSION, wrapper.messageConsumer);
+ List<Object> out = new ArrayList<>();
+ handler.decode(null, buf, out);
+
+ Assert.assertNotNull(wrapper.messageIn);
+ Assert.assertEquals(MSG_ID, wrapper.id);
+ Assert.assertEquals(msgOut.from, wrapper.messageIn.from);
+ Assert.assertEquals(msgOut.verb, wrapper.messageIn.verb);
+ Assert.assertTrue(out.isEmpty());
+
+ return wrapper;
+ }
+
+ private void serialize(MessageOut msgOut) throws IOException
+ {
+ buf = Unpooled.buffer(1024, 1024); // 1k should be enough for everybody!
+ buf.writeInt(MessagingService.PROTOCOL_MAGIC);
+ buf.writeInt(MSG_ID); // this is the id
+ buf.writeInt((int) NanoTimeToCurrentTimeMillis.convert(System.nanoTime()));
+
+ msgOut.serialize(new ByteBufDataOutputPlus(buf), MSG_VERSION);
+ }
+
+ @Test
+ public void decode_WithHalfReceivedParameters() throws Exception
+ {
+ MessageOut msgOut = new MessageOut(MessagingService.Verb.ECHO);
+ msgOut = msgOut.withParameter("p3", "val1".getBytes(Charsets.UTF_8));
+
+ serialize(msgOut);
+
+ // move the write index pointer back a few bytes to simulate like the full bytes are not present.
+ // yeah, it's lame, but it tests the basics of what is happening during the deserialiization
+ int originalWriterIndex = buf.writerIndex();
+ buf.writerIndex(originalWriterIndex - 6);
+
+ MessageInWrapper wrapper = new MessageInWrapper();
+ MessageInHandler handler = new MessageInHandler(addr.getAddress(), MSG_VERSION, wrapper.messageConsumer);
+ List<Object> out = new ArrayList<>();
+ handler.decode(null, buf, out);
+
+ Assert.assertNull(wrapper.messageIn);
+
+ MessageHeader header = handler.getMessageHeader();
+ Assert.assertEquals(MSG_ID, header.messageId);
+ Assert.assertEquals(msgOut.verb, header.verb);
+ Assert.assertEquals(msgOut.from, header.from);
+ Assert.assertTrue(out.isEmpty());
+
+ // now, set the writer index back to the original value to pretend that we actually got more bytes in
+ buf.writerIndex(originalWriterIndex);
+ handler.decode(null, buf, out);
+ Assert.assertNotNull(wrapper.messageIn);
+ Assert.assertTrue(out.isEmpty());
+ }
+
+ @Test
+ public void canReadNextParam_HappyPath() throws IOException
+ {
+ buildParamBuf(13);
+ Assert.assertTrue(MessageInHandler.canReadNextParam(buf));
+ }
+
+ @Test
+ public void canReadNextParam_OnlyFirstByte() throws IOException
+ {
+ buildParamBuf(13);
+ buf.writerIndex(1);
+ Assert.assertFalse(MessageInHandler.canReadNextParam(buf));
+ }
+
+ @Test
+ public void canReadNextParam_PartialUTF() throws IOException
+ {
+ buildParamBuf(13);
+ buf.writerIndex(5);
+ Assert.assertFalse(MessageInHandler.canReadNextParam(buf));
+ }
+
+ @Test
+ public void canReadNextParam_TruncatedValueLength() throws IOException
+ {
+ buildParamBuf(13);
+ buf.writerIndex(buf.writerIndex() - 13 - 2);
+ Assert.assertFalse(MessageInHandler.canReadNextParam(buf));
+ }
+
+ @Test
+ public void canReadNextParam_MissingLastBytes() throws IOException
+ {
+ buildParamBuf(13);
+ buf.writerIndex(buf.writerIndex() - 2);
+ Assert.assertFalse(MessageInHandler.canReadNextParam(buf));
+ }
+
+ private void buildParamBuf(int valueLength) throws IOException
+ {
+ buf = Unpooled.buffer(1024, 1024); // 1k should be enough for everybody!
+ ByteBufDataOutputPlus output = new ByteBufDataOutputPlus(buf);
+ output.writeUTF("name");
+ byte[] array = new byte[valueLength];
+ output.writeInt(array.length);
+ output.write(array);
+ }
+
+ @Test
+ public void exceptionHandled()
+ {
+ MessageInHandler handler = new MessageInHandler(addr.getAddress(), MSG_VERSION, null);
+ EmbeddedChannel channel = new EmbeddedChannel(handler);
+ Assert.assertTrue(channel.isOpen());
+ handler.exceptionCaught(channel.pipeline().firstContext(), new EOFException());
+ Assert.assertFalse(channel.isOpen());
+ }
+
+ private static class MessageInWrapper
+ {
+ MessageIn messageIn;
+ int id;
+
+ final BiConsumer<MessageIn, Integer> messageConsumer = (messageIn, integer) ->
+ {
+ this.messageIn = messageIn;
+ this.id = integer;
+ };
+ }
+}
http://git-wip-us.apache.org/repos/asf/cassandra/blob/356dc3c2/test/unit/org/apache/cassandra/net/async/MessageOutHandlerTest.java
----------------------------------------------------------------------
diff --git a/test/unit/org/apache/cassandra/net/async/MessageOutHandlerTest.java b/test/unit/org/apache/cassandra/net/async/MessageOutHandlerTest.java
new file mode 100644
index 0000000..566dfdb
--- /dev/null
+++ b/test/unit/org/apache/cassandra/net/async/MessageOutHandlerTest.java
@@ -0,0 +1,289 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.cassandra.net.async;
+
+import java.io.IOException;
+import java.net.InetSocketAddress;
+import java.util.HashMap;
+import java.util.Optional;
+import java.util.UUID;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.TimeoutException;
+
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+import com.sun.org.apache.bcel.internal.generic.DDIV;
+import io.netty.channel.ChannelFuture;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.channel.ChannelOutboundHandlerAdapter;
+import io.netty.channel.ChannelPromise;
+import io.netty.channel.DefaultChannelPromise;
+import io.netty.channel.embedded.EmbeddedChannel;
+import io.netty.handler.codec.UnsupportedMessageTypeException;
+import io.netty.handler.timeout.IdleStateEvent;
+import org.apache.cassandra.config.DatabaseDescriptor;
+import org.apache.cassandra.io.IVersionedSerializer;
+import org.apache.cassandra.io.util.DataInputPlus;
+import org.apache.cassandra.io.util.DataOutputPlus;
+import org.apache.cassandra.net.MessageOut;
+import org.apache.cassandra.net.MessagingService;
+import org.apache.cassandra.tracing.Tracing;
+import org.apache.cassandra.utils.UUIDGen;
+
+public class MessageOutHandlerTest
+{
+ private static final int MESSAGING_VERSION = MessagingService.current_version;
+
+ private ChannelWriter channelWriter;
+ private EmbeddedChannel channel;
+ private MessageOutHandler handler;
+
+ @BeforeClass
+ public static void before()
+ {
+ DatabaseDescriptor.daemonInitialization();
+ DatabaseDescriptor.createAllDirectories();
+ }
+
+ @Before
+ public void setup()
+ {
+ setup(MessageOutHandler.AUTO_FLUSH_THRESHOLD);
+ }
+
+ private void setup(int flushThreshold)
+ {
+ OutboundConnectionIdentifier connectionId = OutboundConnectionIdentifier.small(new InetSocketAddress("127.0.0.1", 0),
+ new InetSocketAddress("127.0.0.2", 0));
+ OutboundMessagingConnection omc = new NonSendingOutboundMessagingConnection(connectionId, null, Optional.empty());
+ channel = new EmbeddedChannel();
+ channelWriter = ChannelWriter.create(channel, omc::handleMessageResult, Optional.empty());
+ handler = new MessageOutHandler(connectionId, MESSAGING_VERSION, channelWriter, () -> null, flushThreshold);
+ channel.pipeline().addLast(handler);
+ }
+
+ @Test
+ public void write_NoFlush() throws ExecutionException, InterruptedException, TimeoutException
+ {
+ MessageOut message = new MessageOut(MessagingService.Verb.ECHO);
+ ChannelFuture future = channel.write(new QueuedMessage(message, 42));
+ Assert.assertTrue(!future.isDone());
+ Assert.assertFalse(channel.releaseOutbound());
+ }
+
+ @Test
+ public void write_WithFlush() throws ExecutionException, InterruptedException, TimeoutException
+ {
+ setup(1);
+ MessageOut message = new MessageOut(MessagingService.Verb.ECHO);
+ ChannelFuture future = channel.write(new QueuedMessage(message, 42));
+ Assert.assertTrue(future.isSuccess());
+ Assert.assertTrue(channel.releaseOutbound());
+ }
+
+ @Test
+ public void serializeMessage() throws IOException
+ {
+ channelWriter.pendingMessageCount.set(1);
+ QueuedMessage msg = new QueuedMessage(new MessageOut(MessagingService.Verb.INTERNAL_RESPONSE), 1);
+ ChannelFuture future = channel.writeAndFlush(msg);
+
+ Assert.assertTrue(future.isSuccess());
+ Assert.assertTrue(1 <= channel.outboundMessages().size());
+ Assert.assertTrue(channel.releaseOutbound());
+ }
+
+ @Test
+ public void wrongMessageType()
+ {
+ ChannelPromise promise = new DefaultChannelPromise(channel);
+ Assert.assertFalse(handler.isMessageValid("this is the wrong message type", promise));
+
+ Assert.assertFalse(promise.isSuccess());
+ Assert.assertNotNull(promise.cause());
+ Assert.assertSame(UnsupportedMessageTypeException.class, promise.cause().getClass());
+ }
+
+ @Test
+ public void unexpiredMessage()
+ {
+ QueuedMessage msg = new QueuedMessage(new MessageOut(MessagingService.Verb.INTERNAL_RESPONSE), 1);
+ ChannelPromise promise = new DefaultChannelPromise(channel);
+ Assert.assertTrue(handler.isMessageValid(msg, promise));
+
+ // we won't know if it was successful yet, but we'll know if it's a failure because cause will be set
+ Assert.assertNull(promise.cause());
+ }
+
+ @Test
+ public void expiredMessage()
+ {
+ QueuedMessage msg = new QueuedMessage(new MessageOut(MessagingService.Verb.INTERNAL_RESPONSE), 1, 0, true, true);
+ ChannelPromise promise = new DefaultChannelPromise(channel);
+ Assert.assertFalse(handler.isMessageValid(msg, promise));
+
+ Assert.assertFalse(promise.isSuccess());
+ Assert.assertNotNull(promise.cause());
+ Assert.assertSame(ExpiredException.class, promise.cause().getClass());
+ Assert.assertTrue(channel.outboundMessages().isEmpty());
+ }
+
+ @Test
+ public void write_MessageTooLarge()
+ {
+ write_BadMessageSize(Integer.MAX_VALUE + 1);
+ }
+
+ @Test
+ public void write_MessageSizeIsBananas()
+ {
+ write_BadMessageSize(Integer.MIN_VALUE + 10000);
+ }
+
+ private void write_BadMessageSize(long size)
+ {
+ IVersionedSerializer<Object> serializer = new IVersionedSerializer<Object>()
+ {
+ public void serialize(Object o, DataOutputPlus out, int version)
+ { }
+
+ public Object deserialize(DataInputPlus in, int version)
+ {
+ return null;
+ }
+
+ public long serializedSize(Object o, int version)
+ {
+ return size;
+ }
+ };
+ MessageOut message = new MessageOut(MessagingService.Verb.UNUSED_5, "payload", serializer);
+ ChannelFuture future = channel.write(new QueuedMessage(message, 42));
+ Throwable t = future.cause();
+ Assert.assertNotNull(t);
+ Assert.assertSame(IllegalStateException.class, t.getClass());
+ Assert.assertTrue(channel.isOpen());
+ Assert.assertFalse(channel.releaseOutbound());
+ }
+
+ @Test
+ public void writeForceExceptionPath()
+ {
+ IVersionedSerializer<Object> serializer = new IVersionedSerializer<Object>()
+ {
+ public void serialize(Object o, DataOutputPlus out, int version)
+ {
+ throw new RuntimeException("this exception is part of the test - DON'T PANIC");
+ }
+
+ public Object deserialize(DataInputPlus in, int version)
+ {
+ return null;
+ }
+
+ public long serializedSize(Object o, int version)
+ {
+ return 42;
+ }
+ };
+ MessageOut message = new MessageOut(MessagingService.Verb.UNUSED_5, "payload", serializer);
+ ChannelFuture future = channel.write(new QueuedMessage(message, 42));
+ Throwable t = future.cause();
+ Assert.assertNotNull(t);
+ Assert.assertFalse(channel.isOpen());
+ Assert.assertFalse(channel.releaseOutbound());
+ }
+
+ @Test
+ public void captureTracingInfo_ForceException()
+ {
+ MessageOut message = new MessageOut(MessagingService.Verb.INTERNAL_RESPONSE)
+ .withParameter(Tracing.TRACE_HEADER, new byte[9]);
+ handler.captureTracingInfo(new QueuedMessage(message, 42));
+ }
+
+ @Test
+ public void captureTracingInfo_UnknownSession()
+ {
+ UUID uuid = UUID.randomUUID();
+ MessageOut message = new MessageOut(MessagingService.Verb.INTERNAL_RESPONSE)
+ .withParameter(Tracing.TRACE_HEADER, UUIDGen.decompose(uuid));
+ handler.captureTracingInfo(new QueuedMessage(message, 42));
+ }
+
+ @Test
+ public void captureTracingInfo_KnownSession()
+ {
+ Tracing.instance.newSession(new HashMap<>());
+ MessageOut message = new MessageOut(MessagingService.Verb.REQUEST_RESPONSE);
+ handler.captureTracingInfo(new QueuedMessage(message, 42));
+ }
+
+ @Test
+ public void userEventTriggered_RandomObject()
+ {
+ Assert.assertTrue(channel.isOpen());
+ ChannelUserEventSender sender = new ChannelUserEventSender();
+ channel.pipeline().addFirst(sender);
+ sender.sendEvent("ThisIsAFakeEvent");
+ Assert.assertTrue(channel.isOpen());
+ }
+
+ @Test
+ public void userEventTriggered_Idle_NoPendingBytes()
+ {
+ Assert.assertTrue(channel.isOpen());
+ ChannelUserEventSender sender = new ChannelUserEventSender();
+ channel.pipeline().addFirst(sender);
+ sender.sendEvent(IdleStateEvent.WRITER_IDLE_STATE_EVENT);
+ Assert.assertTrue(channel.isOpen());
+ }
+
+ @Test
+ public void userEventTriggered_Idle_WithPendingBytes()
+ {
+ Assert.assertTrue(channel.isOpen());
+ ChannelUserEventSender sender = new ChannelUserEventSender();
+ channel.pipeline().addFirst(sender);
+
+ MessageOut message = new MessageOut(MessagingService.Verb.INTERNAL_RESPONSE);
+ channel.writeOutbound(new QueuedMessage(message, 42));
+ sender.sendEvent(IdleStateEvent.WRITER_IDLE_STATE_EVENT);
+ Assert.assertFalse(channel.isOpen());
+ }
+
+ private static class ChannelUserEventSender extends ChannelOutboundHandlerAdapter
+ {
+ private ChannelHandlerContext ctx;
+
+ @Override
+ public void handlerAdded(final ChannelHandlerContext ctx) throws Exception
+ {
+ this.ctx = ctx;
+ }
+
+ private void sendEvent(Object event)
+ {
+ ctx.fireUserEventTriggered(event);
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/cassandra/blob/356dc3c2/test/unit/org/apache/cassandra/net/async/NettyFactoryTest.java
----------------------------------------------------------------------
diff --git a/test/unit/org/apache/cassandra/net/async/NettyFactoryTest.java b/test/unit/org/apache/cassandra/net/async/NettyFactoryTest.java
new file mode 100644
index 0000000..c4cc7e6
--- /dev/null
+++ b/test/unit/org/apache/cassandra/net/async/NettyFactoryTest.java
@@ -0,0 +1,300 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.cassandra.net.async;
+
+import java.net.InetAddress;
+import java.net.InetSocketAddress;
+import java.util.Optional;
+
+import com.google.common.net.InetAddresses;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+import io.netty.bootstrap.Bootstrap;
+import io.netty.channel.Channel;
+import io.netty.channel.DefaultEventLoop;
+import io.netty.channel.EventLoopGroup;
+import io.netty.channel.epoll.EpollEventLoopGroup;
+import io.netty.channel.epoll.EpollServerSocketChannel;
+import io.netty.channel.group.ChannelGroup;
+import io.netty.channel.group.DefaultChannelGroup;
+import io.netty.channel.nio.NioEventLoopGroup;
+import io.netty.channel.socket.nio.NioServerSocketChannel;
+import io.netty.channel.socket.nio.NioSocketChannel;
+import io.netty.handler.ssl.SslHandler;
+import io.netty.util.concurrent.GlobalEventExecutor;
+import org.apache.cassandra.auth.AllowAllInternodeAuthenticator;
+import org.apache.cassandra.auth.IInternodeAuthenticator;
+import org.apache.cassandra.config.Config;
+import org.apache.cassandra.config.DatabaseDescriptor;
+import org.apache.cassandra.config.EncryptionOptions.ServerEncryptionOptions;
+import org.apache.cassandra.config.EncryptionOptions.ServerEncryptionOptions.InternodeEncryption;
+import org.apache.cassandra.exceptions.ConfigurationException;
+import org.apache.cassandra.net.MessagingService;
+import org.apache.cassandra.net.async.NettyFactory.InboundInitializer;
+import org.apache.cassandra.net.async.NettyFactory.OutboundInitializer;
+import org.apache.cassandra.utils.FBUtilities;
+import org.apache.cassandra.utils.NativeLibrary;
+
+/**
+ * Unit tests for {@link NettyFactory}. They cover:
+ * <ul>
+ * <li>server channel creation and binding;</li>
+ * <li>accept-group sizing;</li>
+ * <li>event loop group selection;</li>
+ * <li>outbound bootstraps;</li>
+ * <li>SSL handler installation by the inbound and outbound channel initializers.</li>
+ * </ul>
+ */
+public class NettyFactoryTest
+{
+ private static final InetSocketAddress LOCAL_ADDR = new InetSocketAddress("127.0.0.1", 9876);
+ private static final InetSocketAddress REMOTE_ADDR = new InetSocketAddress("127.0.0.2", 9876);
+ private static final int receiveBufferSize = 1 << 16;
+ private static final IInternodeAuthenticator AUTHENTICATOR = new AllowAllInternodeAuthenticator();
+
+ private ChannelGroup channelGroup;
+ /** Factory created by the current test, if any; closed in {@link #tearDown()}. */
+ private NettyFactory factory;
+
+ @BeforeClass
+ public static void before()
+ {
+ DatabaseDescriptor.daemonInitialization();
+ }
+
+ @Before
+ public void setUp()
+ {
+ channelGroup = new DefaultChannelGroup(GlobalEventExecutor.INSTANCE);
+ }
+
+ @After
+ public void tearDown()
+ {
+ if (factory != null)
+ factory.close();
+ }
+
+ @Test
+ public void createServerChannel_Epoll()
+ {
+ Channel inboundChannel = createServerChannel(true);
+ if (inboundChannel == null)
+ return; // creation failed on a non-Linux OS, where epoll is not expected to work
+ Assert.assertEquals(EpollServerSocketChannel.class, inboundChannel.getClass());
+ inboundChannel.close();
+ }
+
+ /**
+ * Creates a factory (kept in {@link #factory} for cleanup) and an inbound channel on {@link #LOCAL_ADDR}.
+ *
+ * @return the inbound channel, or {@code null} if creation failed on a non-Linux OS; on Linux any failure is rethrown
+ */
+ private Channel createServerChannel(boolean useEpoll)
+ {
+ InboundInitializer inboundInitializer = new InboundInitializer(AUTHENTICATOR, null, channelGroup);
+ factory = new NettyFactory(useEpoll);
+
+ try
+ {
+ return factory.createInboundChannel(LOCAL_ADDR, inboundInitializer, receiveBufferSize);
+ }
+ catch (Exception e)
+ {
+ if (NativeLibrary.osType == NativeLibrary.OSType.LINUX)
+ throw e;
+
+ return null;
+ }
+ }
+
+ @Test
+ public void createServerChannel_Nio()
+ {
+ Channel inboundChannel = createServerChannel(false);
+ Assert.assertNotNull("we should always be able to get a NIO channel", inboundChannel);
+ Assert.assertEquals(NioServerSocketChannel.class, inboundChannel.getClass());
+ inboundChannel.close();
+ }
+
+ /** Creating a second inbound channel on the same address must fail with {@link ConfigurationException}. */
+ @Test(expected = ConfigurationException.class)
+ public void createServerChannel_SecondAttemptToBind()
+ {
+ Channel inboundChannel = null;
+ try
+ {
+ InboundInitializer inboundInitializer = new InboundInitializer(AUTHENTICATOR, null, channelGroup);
+ inboundChannel = NettyFactory.instance.createInboundChannel(LOCAL_ADDR, inboundInitializer, receiveBufferSize);
+ NettyFactory.instance.createInboundChannel(LOCAL_ADDR, inboundInitializer, receiveBufferSize);
+ }
+ finally
+ {
+ if (inboundChannel != null)
+ inboundChannel.close();
+ }
+ }
+
+ /**
+ * Creating an inbound channel on an address this host presumably does not own (1.1.1.1) must fail with
+ * {@link ConfigurationException}.
+ */
+ @Test(expected = ConfigurationException.class)
+ public void createServerChannel_UnbindableAddress()
+ {
+ InetSocketAddress addr = new InetSocketAddress("1.1.1.1", 9876);
+ InboundInitializer inboundInitializer = new InboundInitializer(AUTHENTICATOR, null, channelGroup);
+ NettyFactory.instance.createInboundChannel(addr, inboundInitializer, receiveBufferSize);
+ }
+
+ /**
+ * Checks the accept-group size for each internode encryption mode. The expected sizes double when
+ * listening on a broadcast address that differs from the local address.
+ */
+ @Test
+ public void determineAcceptGroupSize()
+ {
+ Assert.assertEquals(1, NettyFactory.determineAcceptGroupSize(InternodeEncryption.none));
+ Assert.assertEquals(1, NettyFactory.determineAcceptGroupSize(InternodeEncryption.all));
+ Assert.assertEquals(2, NettyFactory.determineAcceptGroupSize(InternodeEncryption.rack));
+ Assert.assertEquals(2, NettyFactory.determineAcceptGroupSize(InternodeEncryption.dc));
+
+ InetAddress originalBroadcastAddr = FBUtilities.getBroadcastAddress();
+ try
+ {
+ FBUtilities.setBroadcastInetAddress(InetAddresses.increment(FBUtilities.getLocalAddress()));
+ DatabaseDescriptor.setListenOnBroadcastAddress(true);
+
+ Assert.assertEquals(2, NettyFactory.determineAcceptGroupSize(InternodeEncryption.none));
+ Assert.assertEquals(2, NettyFactory.determineAcceptGroupSize(InternodeEncryption.all));
+ Assert.assertEquals(4, NettyFactory.determineAcceptGroupSize(InternodeEncryption.rack));
+ Assert.assertEquals(4, NettyFactory.determineAcceptGroupSize(InternodeEncryption.dc));
+ }
+ finally
+ {
+ FBUtilities.setBroadcastInetAddress(originalBroadcastAddr);
+ // NOTE(review): assumes listen_on_broadcast_address was false before this test; confirm
+ DatabaseDescriptor.setListenOnBroadcastAddress(false);
+ }
+ }
+
+ @Test
+ public void getEventLoopGroup_EpollWithIoRatioBoost()
+ {
+ shutdown(getEventLoopGroup_Epoll(true));
+ }
+
+ /**
+ * Requests an epoll event loop group and asserts its type.
+ *
+ * @return the group, or {@code null} if creation failed on a non-Linux OS; on Linux any failure is rethrown
+ */
+ private EpollEventLoopGroup getEventLoopGroup_Epoll(boolean ioBoost)
+ {
+ EventLoopGroup eventLoopGroup;
+ try
+ {
+ eventLoopGroup = NettyFactory.getEventLoopGroup(true, 1, "testEventLoopGroup", ioBoost);
+ }
+ catch (Exception e)
+ {
+ if (NativeLibrary.osType == NativeLibrary.OSType.LINUX)
+ throw e;
+
+ // ignore as epoll is only available on linux platforms, so don't fail the test on other OSes
+ return null;
+ }
+
+ Assert.assertTrue(eventLoopGroup instanceof EpollEventLoopGroup);
+ return (EpollEventLoopGroup) eventLoopGroup;
+ }
+
+ @Test
+ public void getEventLoopGroup_EpollWithoutIoRatioBoost()
+ {
+ shutdown(getEventLoopGroup_Epoll(false));
+ }
+
+ @Test
+ public void getEventLoopGroup_NioWithoutIoRatioBoost()
+ {
+ shutdown(getEventLoopGroup_Nio(false));
+ }
+
+ /** Requests a NIO event loop group and asserts its type. */
+ private NioEventLoopGroup getEventLoopGroup_Nio(boolean ioBoost)
+ {
+ EventLoopGroup eventLoopGroup = NettyFactory.getEventLoopGroup(false, 1, "testEventLoopGroup", ioBoost);
+ Assert.assertTrue(eventLoopGroup instanceof NioEventLoopGroup);
+ return (NioEventLoopGroup) eventLoopGroup;
+ }
+
+ @Test
+ public void getEventLoopGroup_NioWithIoRatioBoost()
+ {
+ shutdown(getEventLoopGroup_Nio(true));
+ }
+
+ /**
+ * Shuts down an event loop group created by a test so its threads do not outlive the test.
+ * Accepts {@code null} for the case where epoll was unavailable.
+ */
+ private static void shutdown(EventLoopGroup group)
+ {
+ if (group != null)
+ group.shutdownGracefully();
+ }
+
+ @Test
+ public void createOutboundBootstrap_Epoll()
+ {
+ Bootstrap bootstrap = createOutboundBootstrap(true);
+ Assert.assertEquals(EpollEventLoopGroup.class, bootstrap.config().group().getClass());
+ }
+
+ /** Creates a factory (kept in {@link #factory} for cleanup) and a gossip outbound bootstrap from it. */
+ private Bootstrap createOutboundBootstrap(boolean useEpoll)
+ {
+ factory = new NettyFactory(useEpoll);
+ OutboundConnectionIdentifier id = OutboundConnectionIdentifier.gossip(LOCAL_ADDR, REMOTE_ADDR);
+ OutboundConnectionParams params = OutboundConnectionParams.builder()
+ .connectionId(id)
+ .coalescingStrategy(Optional.empty())
+ .protocolVersion(MessagingService.current_version)
+ .build();
+ return factory.createOutboundBootstrap(params);
+ }
+
+ @Test
+ public void createOutboundBootstrap_Nio()
+ {
+ Bootstrap bootstrap = createOutboundBootstrap(false);
+ Assert.assertEquals(NioEventLoopGroup.class, bootstrap.config().group().getClass());
+ }
+
+ @Test
+ public void createInboundInitializer_WithoutSsl() throws Exception
+ {
+ InboundInitializer initializer = new InboundInitializer(AUTHENTICATOR, null, channelGroup);
+ NioSocketChannel channel = new NioSocketChannel();
+ initializer.initChannel(channel);
+ Assert.assertNull(channel.pipeline().get(SslHandler.class));
+ }
+
+ /** Server encryption options that use the test keystore/truststore and a single fixed cipher suite. */
+ private ServerEncryptionOptions encOptions()
+ {
+ ServerEncryptionOptions encryptionOptions = new ServerEncryptionOptions();
+ encryptionOptions.keystore = "test/conf/cassandra_ssl_test.keystore";
+ encryptionOptions.keystore_password = "cassandra";
+ encryptionOptions.truststore = "test/conf/cassandra_ssl_test.truststore";
+ encryptionOptions.truststore_password = "cassandra";
+ encryptionOptions.require_client_auth = false;
+ encryptionOptions.cipher_suites = new String[] {"TLS_RSA_WITH_AES_128_CBC_SHA"};
+ return encryptionOptions;
+ }
+
+ @Test
+ public void createInboundInitializer_WithSsl() throws Exception
+ {
+ ServerEncryptionOptions encryptionOptions = encOptions();
+ InboundInitializer initializer = new InboundInitializer(AUTHENTICATOR, encryptionOptions, channelGroup);
+ NioSocketChannel channel = new NioSocketChannel();
+ Assert.assertNull(channel.pipeline().get(SslHandler.class));
+ initializer.initChannel(channel);
+ Assert.assertNotNull(channel.pipeline().get(SslHandler.class));
+ }
+
+ @Test
+ public void createOutboundInitializer_WithSsl() throws Exception
+ {
+ OutboundConnectionIdentifier id = OutboundConnectionIdentifier.gossip(LOCAL_ADDR, REMOTE_ADDR);
+ OutboundConnectionParams params = OutboundConnectionParams.builder()
+ .connectionId(id)
+ .encryptionOptions(encOptions())
+ .protocolVersion(MessagingService.current_version)
+ .build();
+ OutboundInitializer outboundInitializer = new OutboundInitializer(params);
+ NioSocketChannel channel = new NioSocketChannel();
+ Assert.assertNull(channel.pipeline().get(SslHandler.class));
+ outboundInitializer.initChannel(channel);
+ Assert.assertNotNull(channel.pipeline().get(SslHandler.class));
+ }
+}
http://git-wip-us.apache.org/repos/asf/cassandra/blob/356dc3c2/test/unit/org/apache/cassandra/net/async/NonSendingOutboundMessagingConnection.java
----------------------------------------------------------------------
diff --git a/test/unit/org/apache/cassandra/net/async/NonSendingOutboundMessagingConnection.java b/test/unit/org/apache/cassandra/net/async/NonSendingOutboundMessagingConnection.java
new file mode 100644
index 0000000..b0b15b8
--- /dev/null
+++ b/test/unit/org/apache/cassandra/net/async/NonSendingOutboundMessagingConnection.java
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.cassandra.net.async;
+
+import java.util.Optional;
+
+import org.apache.cassandra.auth.AllowAllInternodeAuthenticator;
+import org.apache.cassandra.config.EncryptionOptions;
+import org.apache.cassandra.utils.CoalescingStrategies;
+
+/**
+ * Test double for {@link OutboundMessagingConnection} whose {@code sendMessage} never sends anything:
+ * it only records that it was called and reports success.
+ */
+class NonSendingOutboundMessagingConnection extends OutboundMessagingConnection
+{
+ /** Becomes {@code true} the first time {@code sendMessage} is called; never reset by this class. */
+ boolean sendMessageInvoked;
+
+ /**
+ * Passes all arguments straight to the superclass and always supplies an
+ * {@link AllowAllInternodeAuthenticator} as the authenticator.
+ *
+ * @param connectionId identifier of the outbound connection
+ * @param encryptionOptions server encryption options for the connection
+ * @param coalescingStrategy optional coalescing strategy for the connection
+ */
+ NonSendingOutboundMessagingConnection(OutboundConnectionIdentifier connectionId, EncryptionOptions.ServerEncryptionOptions encryptionOptions, Optional<CoalescingStrategies.CoalescingStrategy> coalescingStrategy)
+ {
+ super(connectionId, encryptionOptions, coalescingStrategy, new AllowAllInternodeAuthenticator());
+ }
+
+ /**
+ * Records the invocation instead of sending the message.
+ *
+ * @param message the message that would have been sent (ignored)
+ * @return always {@code true}
+ */
+ @Override
+ boolean sendMessage(QueuedMessage message)
+ {
+ // assignment expression: sets the flag and yields true in one step
+ return sendMessageInvoked = true;
+ }
+}
http://git-wip-us.apache.org/repos/asf/cassandra/blob/356dc3c2/test/unit/org/apache/cassandra/net/async/OutboundConnectionParamsTest.java
----------------------------------------------------------------------
diff --git a/test/unit/org/apache/cassandra/net/async/OutboundConnectionParamsTest.java b/test/unit/org/apache/cassandra/net/async/OutboundConnectionParamsTest.java
new file mode 100644
index 0000000..0ce4968
--- /dev/null
+++ b/test/unit/org/apache/cassandra/net/async/OutboundConnectionParamsTest.java
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.cassandra.net.async;
+
+import org.junit.Test;
+
+/**
+ * Validation tests for the send buffer size accepted by {@code OutboundConnectionParams.builder()}.
+ */
+public class OutboundConnectionParamsTest
+{
+ /** 1 GiB: a send buffer size these tests expect the builder to refuse. */
+ private static final int HUGE_SEND_BUFFER_SIZE = 1 << 30;
+
+ /** A negative send buffer size must lead to an {@link IllegalArgumentException}. */
+ @Test(expected = IllegalArgumentException.class)
+ public void build_SendSizeLessThanZero()
+ {
+ OutboundConnectionParams.builder()
+ .sendBufferSize(-1)
+ .build();
+ }
+
+ /** A 1 GiB send buffer size must lead to an {@link IllegalArgumentException}. */
+ @Test(expected = IllegalArgumentException.class)
+ public void build_SendSizeHuge()
+ {
+ OutboundConnectionParams.builder()
+ .sendBufferSize(HUGE_SEND_BUFFER_SIZE)
+ .build();
+ }
+}
http://git-wip-us.apache.org/repos/asf/cassandra/blob/356dc3c2/test/unit/org/apache/cassandra/net/async/OutboundHandshakeHandlerTest.java
----------------------------------------------------------------------
diff --git a/test/unit/org/apache/cassandra/net/async/OutboundHandshakeHandlerTest.java b/test/unit/org/apache/cassandra/net/async/OutboundHandshakeHandlerTest.java
new file mode 100644
index 0000000..f8bfab1
--- /dev/null
+++ b/test/unit/org/apache/cassandra/net/async/OutboundHandshakeHandlerTest.java
@@ -0,0 +1,209 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.cassandra.net.async;
+
+import java.net.InetSocketAddress;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Optional;
+
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.PooledByteBufAllocator;
+import io.netty.buffer.Unpooled;
+import io.netty.channel.ChannelOutboundHandlerAdapter;
+import io.netty.channel.ChannelPipeline;
+import io.netty.channel.embedded.EmbeddedChannel;
+import io.netty.handler.codec.compression.Lz4FrameDecoder;
+import io.netty.handler.codec.compression.Lz4FrameEncoder;
+import org.apache.cassandra.config.DatabaseDescriptor;
+import org.apache.cassandra.net.MessagingService;
+import org.apache.cassandra.net.async.HandshakeProtocol.SecondHandshakeMessage;
+import org.apache.cassandra.net.async.OutboundHandshakeHandler.HandshakeResult;
+
+import static org.apache.cassandra.net.async.OutboundHandshakeHandler.HandshakeResult.UNKNOWN_PROTOCOL_VERSION;
+
+/**
+ * Unit tests for {@link OutboundHandshakeHandler}: decoding of the peer's {@link SecondHandshakeMessage},
+ * the {@link HandshakeResult} handed to the connection callback, and the handlers installed by
+ * {@code setupPipeline()}.
+ */
+public class OutboundHandshakeHandlerTest
+{
+ private static final int MESSAGING_VERSION = MessagingService.current_version;
+ private static final InetSocketAddress localAddr = new InetSocketAddress("127.0.0.1", 0);
+ private static final InetSocketAddress remoteAddr = new InetSocketAddress("127.0.0.2", 0);
+ private static final String HANDLER_NAME = "clientHandshakeHandler";
+
+ private EmbeddedChannel channel;
+ private OutboundHandshakeHandler handler;
+ private OutboundConnectionIdentifier connectionId;
+ private OutboundConnectionParams params;
+ private CallbackHandler callbackHandler;
+ // buffer created by the current test (if any); tearDown() releases it if it is still referenced
+ private ByteBuf buf;
+
+ @BeforeClass
+ public static void before()
+ {
+ DatabaseDescriptor.daemonInitialization();
+ }
+
+ /**
+ * Creates a fresh embedded channel with an {@link OutboundHandshakeHandler} installed first in its
+ * pipeline, configured for messaging mode at the current version and reporting to {@link #callbackHandler}.
+ */
+ @Before
+ public void setup()
+ {
+ channel = new EmbeddedChannel(new ChannelOutboundHandlerAdapter());
+ connectionId = OutboundConnectionIdentifier.small(localAddr, remoteAddr);
+ callbackHandler = new CallbackHandler();
+ params = OutboundConnectionParams.builder()
+ .connectionId(connectionId)
+ .callback(handshakeResult -> callbackHandler.receive(handshakeResult))
+ .mode(NettyFactory.Mode.MESSAGING)
+ .protocolVersion(MessagingService.current_version)
+ .coalescingStrategy(Optional.empty())
+ .build();
+ handler = new OutboundHandshakeHandler(params);
+ channel.pipeline().addFirst(HANDLER_NAME, handler);
+ }
+
+ /**
+ * Releases {@link #buf} if a test left it referenced, then finishes the shared channel and asserts
+ * that no inbound or outbound messages were left behind.
+ */
+ @After
+ public void tearDown()
+ {
+ if (buf != null && buf.refCnt() > 0)
+ buf.release();
+ Assert.assertFalse(channel.finishAndReleaseAll());
+ }
+
+ /** A buffer with nothing readable (capacity 2, nothing written) is neither consumed nor decoded. */
+ @Test
+ public void decode_SmallInput() throws Exception
+ {
+ buf = Unpooled.buffer(2, 2);
+ List<Object> out = new LinkedList<>();
+ handler.decode(channel.pipeline().firstContext(), buf, out);
+ Assert.assertEquals(0, buf.readerIndex());
+ Assert.assertTrue(out.isEmpty());
+ }
+
+ /**
+ * A peer answering at our own version yields SUCCESS with that version; the channel stays open and
+ * exactly one outbound message is written.
+ */
+ @Test
+ public void decode_HappyPath() throws Exception
+ {
+ buf = new SecondHandshakeMessage(MESSAGING_VERSION).encode(PooledByteBufAllocator.DEFAULT);
+ channel.writeInbound(buf);
+ Assert.assertEquals(1, channel.outboundMessages().size());
+ Assert.assertTrue(channel.isOpen());
+ Assert.assertTrue(channel.releaseOutbound()); // throw away any responses from decode()
+
+ Assert.assertEquals(MESSAGING_VERSION, callbackHandler.result.negotiatedMessagingVersion);
+ Assert.assertEquals(HandshakeResult.Outcome.SUCCESS, callbackHandler.result.outcome);
+ }
+
+ /**
+ * If the callback throws while being told about a successful handshake, the channel is closed and the
+ * callback is invoked again with NEGOTIATION_FAILURE and an unknown protocol version.
+ */
+ @Test
+ public void decode_HappyPathThrowsException() throws Exception
+ {
+ // make the first (success) callback throw
+ callbackHandler.failOnCallback = true;
+ buf = new SecondHandshakeMessage(MESSAGING_VERSION).encode(PooledByteBufAllocator.DEFAULT);
+ channel.writeInbound(buf);
+ Assert.assertFalse(channel.isOpen());
+ Assert.assertEquals(1, channel.outboundMessages().size());
+ Assert.assertTrue(channel.releaseOutbound()); // throw away any responses from decode()
+
+ Assert.assertEquals(UNKNOWN_PROTOCOL_VERSION, callbackHandler.result.negotiatedMessagingVersion);
+ Assert.assertEquals(HandshakeResult.Outcome.NEGOTIATION_FAILURE, callbackHandler.result.outcome);
+ }
+
+ /**
+ * A peer answering with a version lower than ours yields DISCONNECT carrying the peer's version; the
+ * channel is closed and nothing is written.
+ */
+ @Test
+ public void decode_ReceivedLowerMsgVersion() throws Exception
+ {
+ int msgVersion = MESSAGING_VERSION - 1;
+ buf = new SecondHandshakeMessage(msgVersion).encode(PooledByteBufAllocator.DEFAULT);
+ channel.writeInbound(buf);
+ Assert.assertTrue(channel.inboundMessages().isEmpty());
+
+ Assert.assertEquals(msgVersion, callbackHandler.result.negotiatedMessagingVersion);
+ Assert.assertEquals(HandshakeResult.Outcome.DISCONNECT, callbackHandler.result.outcome);
+ Assert.assertFalse(channel.isOpen());
+ Assert.assertTrue(channel.outboundMessages().isEmpty());
+ }
+
+ /**
+ * With our handler configured one version below current, a peer answering at the current (higher)
+ * version yields DISCONNECT carrying the peer's version.
+ */
+ @Test
+ public void decode_ReceivedHigherMsgVersion() throws Exception
+ {
+ // rebuild the handler so that our own version is one below the version the peer will announce
+ int msgVersion = MESSAGING_VERSION - 1;
+ channel.pipeline().remove(HANDLER_NAME);
+ params = OutboundConnectionParams.builder()
+ .connectionId(connectionId)
+ .callback(handshakeResult -> callbackHandler.receive(handshakeResult))
+ .mode(NettyFactory.Mode.MESSAGING)
+ .protocolVersion(msgVersion)
+ .coalescingStrategy(Optional.empty())
+ .build();
+ handler = new OutboundHandshakeHandler(params);
+ channel.pipeline().addFirst(HANDLER_NAME, handler);
+ buf = new SecondHandshakeMessage(MESSAGING_VERSION).encode(PooledByteBufAllocator.DEFAULT);
+ channel.writeInbound(buf);
+
+ Assert.assertEquals(MESSAGING_VERSION, callbackHandler.result.negotiatedMessagingVersion);
+ Assert.assertEquals(HandshakeResult.Outcome.DISCONNECT, callbackHandler.result.outcome);
+ }
+
+ /** With compression enabled, setupPipeline() installs an LZ4 frame encoder (but no decoder) and a MessageOutHandler. */
+ @Test
+ public void setupPipeline_WithCompression()
+ {
+ EmbeddedChannel chan = new EmbeddedChannel(new ChannelOutboundHandlerAdapter());
+ try
+ {
+ ChannelPipeline pipeline = chan.pipeline();
+ params = OutboundConnectionParams.builder(params).compress(true).protocolVersion(MessagingService.current_version).build();
+ handler = new OutboundHandshakeHandler(params);
+ pipeline.addFirst(handler);
+ handler.setupPipeline(chan, MESSAGING_VERSION);
+ Assert.assertNotNull(pipeline.get(Lz4FrameEncoder.class));
+ Assert.assertNull(pipeline.get(Lz4FrameDecoder.class));
+ Assert.assertNotNull(pipeline.get(MessageOutHandler.class));
+ }
+ finally
+ {
+ // chan is local to this test, so tearDown() never sees it: close it and release anything
+ // left in its inbound/outbound queues
+ chan.finishAndReleaseAll();
+ }
+ }
+
+ /** With compression disabled, setupPipeline() installs a MessageOutHandler and no LZ4 codecs. */
+ @Test
+ public void setupPipeline_NoCompression()
+ {
+ EmbeddedChannel chan = new EmbeddedChannel(new ChannelOutboundHandlerAdapter());
+ try
+ {
+ ChannelPipeline pipeline = chan.pipeline();
+ params = OutboundConnectionParams.builder(params).compress(false).protocolVersion(MessagingService.current_version).build();
+ handler = new OutboundHandshakeHandler(params);
+ pipeline.addFirst(handler);
+ handler.setupPipeline(chan, MESSAGING_VERSION);
+ Assert.assertNull(pipeline.get(Lz4FrameEncoder.class));
+ Assert.assertNull(pipeline.get(Lz4FrameDecoder.class));
+ Assert.assertNotNull(pipeline.get(MessageOutHandler.class));
+ }
+ finally
+ {
+ // chan is local to this test, so tearDown() never sees it: close it and release anything
+ // left in its inbound/outbound queues
+ chan.finishAndReleaseAll();
+ }
+ }
+
+ /**
+ * Handshake callback that records the most recent {@link HandshakeResult}. When {@link #failOnCallback}
+ * is set, the next invocation throws instead of recording, and clears the flag.
+ */
+ private static class CallbackHandler
+ {
+ boolean failOnCallback;
+ HandshakeResult result;
+
+ /** Records {@code handshakeResult} (unless failing this once); always returns {@code null}. */
+ Void receive(HandshakeResult handshakeResult)
+ {
+ if (failOnCallback)
+ {
+ // only fail the first callback
+ failOnCallback = false;
+ throw new RuntimeException("this exception is expected in the test - DON'T PANIC");
+ }
+ result = handshakeResult;
+ return null;
+ }
+ }
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@cassandra.apache.org
For additional commands, e-mail: commits-help@cassandra.apache.org