You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@cassandra.apache.org by al...@apache.org on 2019/06/12 14:54:53 UTC
[cassandra] 02/02: Introduce a proxy test handler,
extra unit tests for connection closure and message expirations
This is an automated email from the ASF dual-hosted git repository.
aleksey pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/cassandra.git
commit aed15bfb01e753f9bac8b149b382c7a7c8d33183
Author: Alex Petrov <ol...@gmail.com>
AuthorDate: Wed Jun 12 15:02:37 2019 +0100
Introduce a proxy test handler, extra unit tests for connection closure
and message expirations
patch by Alex Petrov; reviewed by Aleksey Yeschenko and Benedict Elliott
Smith for CASSANDRA-15066
---
.../cassandra/net/ProxyHandlerConnectionsTest.java | 405 +++++++++++++++++++++
.../cassandra/net/proxy/InboundProxyHandler.java | 234 ++++++++++++
.../cassandra/net/proxy/ProxyHandlerTest.java | 222 +++++++++++
3 files changed, 861 insertions(+)
diff --git a/test/unit/org/apache/cassandra/net/ProxyHandlerConnectionsTest.java b/test/unit/org/apache/cassandra/net/ProxyHandlerConnectionsTest.java
new file mode 100644
index 0000000..ae41cce
--- /dev/null
+++ b/test/unit/org/apache/cassandra/net/ProxyHandlerConnectionsTest.java
@@ -0,0 +1,405 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.cassandra.net;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.function.Supplier;
+import java.util.function.ToLongFunction;
+
+import com.google.common.util.concurrent.Uninterruptibles;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+import io.netty.buffer.ByteBuf;
+import org.apache.cassandra.config.DatabaseDescriptor;
+import org.apache.cassandra.io.IVersionedAsymmetricSerializer;
+import org.apache.cassandra.io.IVersionedSerializer;
+import org.apache.cassandra.io.util.DataInputPlus;
+import org.apache.cassandra.io.util.DataOutputPlus;
+import org.apache.cassandra.locator.InetAddressAndPort;
+import org.apache.cassandra.net.proxy.InboundProxyHandler;
+import org.apache.cassandra.utils.FBUtilities;
+import org.apache.cassandra.utils.Pair;
+
+import static java.util.concurrent.TimeUnit.MILLISECONDS;
+import static java.util.concurrent.TimeUnit.SECONDS;
+import static org.apache.cassandra.net.ConnectionTest.SETTINGS;
+import static org.apache.cassandra.net.OutboundConnectionSettings.Framing.CRC;
+import static org.apache.cassandra.utils.MonotonicClock.approxTime;
+
+/**
+ * Connection-level tests that insert an {@link InboundProxyHandler} into the inbound pipeline so that traffic
+ * between an {@link OutboundConnection} and {@link InboundSockets} can be delayed, corrupted or cut off, and then
+ * check how message expiration, handshakes and reconnection behave.
+ */
+public class ProxyHandlerConnectionsTest
+{
+    private static final SocketFactory factory = new SocketFactory();
+
+    // Values returned by the first Verb.unsafeSet*() call made for each verb during a test; cleanup() re-applies
+    // them so that overrides do not leak into later tests.
+    private final Map<Verb, Supplier<? extends IVersionedAsymmetricSerializer<?, ?>>> serializers = new HashMap<>();
+    private final Map<Verb, Supplier<? extends IVerbHandler<?>>> handlers = new HashMap<>();
+    private final Map<Verb, ToLongFunction<TimeUnit>> timeouts = new HashMap<>();
+
+    /**
+     * Overrides the serializer used for {@code verb}, remembering the value returned by the first override of
+     * that verb so that {@link #cleanup()} can restore it.
+     */
+    private void unsafeSetSerializer(Verb verb, Supplier<? extends IVersionedAsymmetricSerializer<?, ?>> supplier) throws Throwable
+    {
+        serializers.putIfAbsent(verb, verb.unsafeSetSerializer(supplier));
+    }
+
+    /**
+     * Overrides the handler used for {@code verb}, remembering the value returned by the first override of
+     * that verb so that {@link #cleanup()} can restore it.
+     */
+    protected void unsafeSetHandler(Verb verb, Supplier<? extends IVerbHandler<?>> supplier) throws Throwable
+    {
+        handlers.putIfAbsent(verb, verb.unsafeSetHandler(supplier));
+    }
+
+    /**
+     * Overrides the expiration used for {@code verb}, remembering the value returned by the first override of
+     * that verb so that {@link #cleanup()} can restore it.
+     */
+    private void unsafeSetExpiration(Verb verb, ToLongFunction<TimeUnit> expiration) throws Throwable
+    {
+        timeouts.putIfAbsent(verb, verb.unsafeSetExpiration(expiration));
+    }
+
+    @BeforeClass
+    public static void startup()
+    {
+        DatabaseDescriptor.daemonInitialization();
+    }
+
+    /**
+     * Re-applies every serializer, handler and expiration recorded by the unsafeSet* helpers, then forgets them.
+     */
+    @After
+    public void cleanup() throws Throwable
+    {
+        for (Map.Entry<Verb, Supplier<? extends IVersionedAsymmetricSerializer<?, ?>>> e : serializers.entrySet())
+            e.getKey().unsafeSetSerializer(e.getValue());
+        serializers.clear();
+        for (Map.Entry<Verb, Supplier<? extends IVerbHandler<?>>> e : handlers.entrySet())
+            e.getKey().unsafeSetHandler(e.getValue());
+        handlers.clear();
+        for (Map.Entry<Verb, ToLongFunction<TimeUnit>> e : timeouts.entrySet())
+            e.getKey().unsafeSetExpiration(e.getValue());
+        timeouts.clear();
+    }
+
+    /**
+     * Once the proxy delays messages beyond their verb expiration, every one of them must be counted as expired
+     * on the inbound side and none may reach the verb handler.
+     */
+    @Test
+    public void testExpireInbound() throws Throwable
+    {
+        DatabaseDescriptor.setCrossNodeTimeout(true);
+        testOneManual((settings, inbound, outbound, endpoint, handler) -> {
+            unsafeSetSerializer(Verb._TEST_1, FakePayloadSerializer::new);
+            connect(outbound);
+
+            // Slow things down: messages expire after 50ms, but the proxy holds each one for 100ms
+            unsafeSetExpiration(Verb._TEST_1, unit -> unit.convert(50, MILLISECONDS));
+            handler.withLatency(100, MILLISECONDS);
+
+            unsafeSetHandler(Verb._TEST_1, () -> v -> {
+                throw new RuntimeException("Should have not been triggered " + v);
+            });
+            int expireMessages = 10;
+            for (int i = 0; i < expireMessages; i++)
+                outbound.enqueue(Message.out(Verb._TEST_1, 1L));
+
+            InboundMessageHandlers inboundHandlers = MessagingService.instance().getInbound(endpoint);
+            waitForCondition(() -> inboundHandlers.expiredCount() == expireMessages);
+            Assert.assertEquals(expireMessages, inboundHandlers.expiredCount());
+        });
+    }
+
+    /**
+     * Messages are delivered, then expire while the proxy delays them past the verb expiration, and are
+     * delivered again once the delay drops back below it.
+     */
+    @Test
+    public void testExpireSome() throws Throwable
+    {
+        DatabaseDescriptor.setCrossNodeTimeout(true);
+        testOneManual((settings, inbound, outbound, endpoint, handler) -> {
+            unsafeSetSerializer(Verb._TEST_1, FakePayloadSerializer::new);
+            connect(outbound);
+
+            AtomicInteger counter = new AtomicInteger();
+            unsafeSetHandler(Verb._TEST_1, () -> v -> {
+                counter.incrementAndGet();
+            });
+
+            int expireMessages = 10;
+            for (int i = 0; i < expireMessages; i++)
+                outbound.enqueue(Message.out(Verb._TEST_1, 1L));
+            waitForCondition(() -> counter.get() == expireMessages);
+
+            // 100ms of proxy latency against a 50ms expiration: every message of this round must expire
+            unsafeSetExpiration(Verb._TEST_1, unit -> unit.convert(50, MILLISECONDS));
+            handler.withLatency(100, MILLISECONDS);
+
+            InboundMessageHandlers inboundHandlers = MessagingService.instance().getInbound(endpoint);
+            for (int i = 0; i < expireMessages; i++)
+                outbound.enqueue(Message.out(Verb._TEST_1, 1L));
+            waitForCondition(() -> inboundHandlers.expiredCount() == expireMessages);
+
+            // 2ms of latency is well within the expiration, so this round must be delivered again
+            handler.withLatency(2, MILLISECONDS);
+
+            for (int i = 0; i < expireMessages; i++)
+                outbound.enqueue(Message.out(Verb._TEST_1, 1L));
+            waitForCondition(() -> counter.get() == 2 * expireMessages);
+        });
+    }
+
+    /**
+     * Half of a group of messages, all released for delivery at once, is set up to expire on the way while the
+     * other half must still arrive.
+     */
+    @Test
+    public void testExpireSomeFromBatch() throws Throwable
+    {
+        DatabaseDescriptor.setCrossNodeTimeout(true);
+        testManual((settings, inbound, outbound, endpoint, handler) -> {
+            unsafeSetSerializer(Verb._TEST_1, FakePayloadSerializer::new);
+            connect(outbound);
+
+            Message<?> msg = Message.out(Verb._TEST_1, 1L);
+            int messageSize = msg.serializedSize(MessagingService.current_version);
+            DatabaseDescriptor.setInternodeMaxMessageSizeInBytes(messageSize * 40);
+
+            AtomicInteger counter = new AtomicInteger();
+            unsafeSetHandler(Verb._TEST_1, () -> v -> {
+                counter.incrementAndGet();
+            });
+
+            unsafeSetExpiration(Verb._TEST_1, unit -> unit.convert(200, MILLISECONDS));
+            handler.withLatency(100, MILLISECONDS);
+
+            int totalMessages = 20;
+            long nanoTime = approxTime.now();
+            // Block delivery until every message has been enqueued (presumably so they leave as one batch)
+            CountDownLatch enqueueDone = new CountDownLatch(1);
+            outbound.unsafeRunOnDelivery(() -> Uninterruptibles.awaitUninterruptibly(enqueueDone, 10, SECONDS));
+            for (int i = 0; i < totalMessages; i++)
+            {
+                boolean expire = i % 2 == 0;
+                Message.Builder builder = Message.builder(Verb._TEST_1, 1L);
+
+                if (settings.right.acceptVersions == ConnectionTest.legacy)
+                {
+                    // Only the creation time is set here, so expiry follows from the 200ms verb expiration:
+                    // backdate the messages that should expire by 150ms, leaving them 50 milliseconds to leave
+                    // the outbound path (less than the 100ms proxy latency).
+                    builder.withCreatedAt(nanoTime - (expire ? MILLISECONDS.toNanos(150) : 0));
+                }
+                else
+                {
+                    // Give messages that should expire 50 milliseconds to leave the outbound path
+                    builder.withCreatedAt(nanoTime)
+                           .withExpiresAt(nanoTime + (expire ? MILLISECONDS.toNanos(50) : MILLISECONDS.toNanos(1000)));
+                }
+                outbound.enqueue(builder.build());
+            }
+            enqueueDone.countDown();
+
+            InboundMessageHandlers inboundHandlers = MessagingService.instance().getInbound(endpoint);
+            int half = totalMessages / 2;
+            waitForCondition(() -> inboundHandlers.expiredCount() == half && counter.get() == half,
+                             () -> String.format("Expired: %d, Arrived: %d", inboundHandlers.expiredCount(), counter.get()));
+        });
+    }
+
+    /**
+     * The proxy closes the inbound channel instead of forwarding the first message it reads. The outbound side
+     * must notice the disconnect, never deliver that message, and be able to connect again.
+     */
+    @Test
+    public void suddenDisconnect() throws Throwable
+    {
+        testManual((settings, inbound, outbound, endpoint, handler) -> {
+            // Go back to normal forwarding once the proxy has cut the connection, so that reconnecting works
+            handler.onDisconnect(handler::reset);
+
+            unsafeSetSerializer(Verb._TEST_1, FakePayloadSerializer::new);
+            connect(outbound);
+
+            CountDownLatch closeLatch = new CountDownLatch(1);
+            handler.withCloseAfterRead(closeLatch::countDown);
+            AtomicInteger counter = new AtomicInteger();
+            unsafeSetHandler(Verb._TEST_1, () -> v -> counter.incrementAndGet());
+
+            outbound.enqueue(Message.out(Verb._TEST_1, 1L));
+            waitForCondition(() -> !outbound.isConnected());
+            Assert.assertTrue("Proxy did not close the connection", closeLatch.await(10, SECONDS));
+
+            connect(outbound);
+            Assert.assertTrue(outbound.isConnected());
+            Assert.assertEquals(0, counter.get());
+        });
+    }
+
+    /**
+     * Corrupting handshake bytes in the proxy must make connection attempts fail.
+     */
+    @Test
+    public void testCorruptionOnHandshake() throws Throwable
+    {
+        testManual((settings, inbound, outbound, endpoint, handler) -> {
+            unsafeSetSerializer(Verb._TEST_1, FakePayloadSerializer::new);
+            // Invalid CRC: overwrite a byte in the middle of the readable region (setByte takes an absolute index)
+            handler.withPayloadTransform(msg -> {
+                ByteBuf bb = (ByteBuf) msg;
+                bb.setByte(bb.readerIndex() + bb.readableBytes() / 2, 0xff);
+                return msg;
+            });
+            tryConnect(outbound, 1, SECONDS, false);
+            Assert.assertFalse(outbound.isConnected());
+
+            // Invalid protocol magic: overwrite the first readable byte
+            handler.withPayloadTransform(msg -> {
+                ByteBuf bb = (ByteBuf) msg;
+                bb.setByte(bb.readerIndex(), 0xff);
+                return msg;
+            });
+            tryConnect(outbound, 1, SECONDS, false);
+            Assert.assertFalse(outbound.isConnected());
+            if (settings.right.framing == CRC)
+            {
+                Assert.assertEquals(2, outbound.connectionAttempts());
+                Assert.assertEquals(0, outbound.successfulConnections());
+            }
+        });
+    }
+
+    /**
+     * Waits up to 10 seconds for {@code cond} to hold.
+     *
+     * @throws TimeoutException if the condition does not hold in time
+     */
+    private static void waitForCondition(Supplier<Boolean> cond) throws Throwable
+    {
+        if (!awaitCondition(cond, 10, SECONDS))
+            throw new TimeoutException("Condition not met within 10 seconds");
+    }
+
+    /**
+     * Waits up to 10 seconds for {@code cond} to hold.
+     *
+     * @throws AssertionError carrying the message produced by {@code s} if the condition does not hold in time
+     */
+    private static void waitForCondition(Supplier<Boolean> cond, Supplier<String> s) throws Throwable
+    {
+        if (!awaitCondition(cond, 10, SECONDS))
+            throw new AssertionError(s.get());
+    }
+
+    /**
+     * Polls {@code cond} on the calling thread until it holds or the timeout elapses. Polling on the caller
+     * (rather than spinning in a background task) guarantees nothing keeps running once this method returns.
+     *
+     * @return whether the condition held before the deadline
+     */
+    private static boolean awaitCondition(Supplier<Boolean> cond, long timeout, TimeUnit unit)
+    {
+        long deadline = System.nanoTime() + unit.toNanos(timeout);
+        while (!cond.get())
+        {
+            if (System.nanoTime() - deadline >= 0)
+                return false;
+            Uninterruptibles.sleepUninterruptibly(1, MILLISECONDS);
+        }
+        return true;
+    }
+
+    /**
+     * Serializes a {@code Long} as {@code size} copies of itself; deserialization fails with an
+     * {@link AssertionError} if the copies differ.
+     */
+    private static class FakePayloadSerializer implements IVersionedSerializer<Long>
+    {
+        private final int size;
+
+        private FakePayloadSerializer()
+        {
+            this(1);
+        }
+
+        // Takes long and repeats it size times
+        private FakePayloadSerializer(int size)
+        {
+            this.size = size;
+        }
+
+        public void serialize(Long i, DataOutputPlus out, int version) throws IOException
+        {
+            for (int j = 0; j < size; j++)
+                out.writeLong(i);
+        }
+
+        public Long deserialize(DataInputPlus in, int version) throws IOException
+        {
+            long l = in.readLong();
+            for (int i = 0; i < size - 1; i++)
+            {
+                if (in.readLong() != l)
+                    throw new AssertionError();
+            }
+
+            return l;
+        }
+
+        public long serializedSize(Long t, int version)
+        {
+            return Long.BYTES * size;
+        }
+    }
+
+    /**
+     * Body of a manual connection test. {@code settings} holds the inbound and outbound settings actually used;
+     * {@code handler} controls the proxy placed in the inbound pipeline.
+     */
+    @FunctionalInterface
+    interface ManualSendTest
+    {
+        void accept(Pair<InboundConnectionSettings, OutboundConnectionSettings> settings, InboundSockets inbound, OutboundConnection outbound, InetAddressAndPort endpoint, InboundProxyHandler.Controller handler) throws Throwable;
+    }
+
+    /** Runs {@code test} once for every entry of {@code ConnectionTest.SETTINGS}, cleaning up after each. */
+    private void testManual(ManualSendTest test) throws Throwable
+    {
+        for (ConnectionTest.Settings s : SETTINGS)
+        {
+            doTestManual(s, test);
+            cleanup();
+        }
+    }
+
+    /** Runs {@code test} with the entry at index 1 of {@code ConnectionTest.SETTINGS}. */
+    private void testOneManual(ManualSendTest test) throws Throwable
+    {
+        testOneManual(test, 1);
+    }
+
+    /** Runs {@code test} with the entry at index {@code i} of {@code ConnectionTest.SETTINGS}, then cleans up. */
+    private void testOneManual(ManualSendTest test, int i) throws Throwable
+    {
+        ConnectionTest.Settings s = SETTINGS.get(i);
+        doTestManual(s, test);
+        cleanup();
+    }
+
+    /**
+     * Opens inbound sockets on the local broadcast address with an {@link InboundProxyHandler} appended to each
+     * pipeline, builds an outbound connection to that address, and runs {@code test}. Both ends are always closed
+     * and MessagingService's inbound message handlers are cleared afterwards.
+     */
+    private void doTestManual(ConnectionTest.Settings settings, ManualSendTest test) throws Throwable
+    {
+        InetAddressAndPort endpoint = FBUtilities.getBroadcastAddressAndPort();
+
+        InboundConnectionSettings inboundSettings = settings.inbound.apply(new InboundConnectionSettings())
+                                                                    .withBindAddress(endpoint)
+                                                                    .withSocketFactory(factory);
+
+        InboundSockets inbound = new InboundSockets(Collections.singletonList(inboundSettings));
+
+        OutboundConnectionSettings outboundSettings = settings.outbound.apply(new OutboundConnectionSettings(endpoint))
+                                                                       .withConnectTo(endpoint)
+                                                                       .withDefaultReserveLimits()
+                                                                       .withSocketFactory(factory);
+
+        ResourceLimits.EndpointAndGlobal reserveCapacityInBytes =
+            new ResourceLimits.EndpointAndGlobal(new ResourceLimits.Concurrent(outboundSettings.applicationReserveSendQueueEndpointCapacityInBytes),
+                                                 outboundSettings.applicationReserveSendQueueGlobalCapacityInBytes);
+        OutboundConnection outbound = new OutboundConnection(settings.type, outboundSettings, reserveCapacityInBytes);
+        try
+        {
+            InboundProxyHandler.Controller controller = new InboundProxyHandler.Controller();
+            inbound.open(pipeline -> {
+                InboundProxyHandler handler = new InboundProxyHandler(controller);
+                pipeline.addLast(handler);
+            }).sync();
+            test.accept(Pair.create(inboundSettings, outboundSettings), inbound, outbound, endpoint, controller);
+        }
+        finally
+        {
+            outbound.close(false);
+            inbound.close().get(30L, SECONDS);
+            outbound.close(false).get(30L, SECONDS);
+            MessagingService.instance().messageHandlers.clear();
+        }
+    }
+
+    /** Establishes the connection by delivering one message, failing if that takes longer than 10 seconds. */
+    private void connect(OutboundConnection outbound) throws Throwable
+    {
+        tryConnect(outbound, 10, SECONDS, true);
+    }
+
+    /**
+     * Replaces the {@code _TEST_1} handler with one that counts down a latch, enqueues a single message, and waits
+     * up to {@code timeout} for it to be handled. If {@code throwOnFailure} is set, fails when it was not handled.
+     */
+    private void tryConnect(OutboundConnection outbound, long timeout, TimeUnit timeUnit, boolean throwOnFailure) throws Throwable
+    {
+        CountDownLatch connectionLatch = new CountDownLatch(1);
+        unsafeSetHandler(Verb._TEST_1, () -> v -> {
+            connectionLatch.countDown();
+        });
+        outbound.enqueue(Message.out(Verb._TEST_1, 1L));
+        connectionLatch.await(timeout, timeUnit);
+        if (throwOnFailure)
+            Assert.assertEquals(0, connectionLatch.getCount());
+    }
+}
diff --git a/test/unit/org/apache/cassandra/net/proxy/InboundProxyHandler.java b/test/unit/org/apache/cassandra/net/proxy/InboundProxyHandler.java
new file mode 100644
index 0000000..7e3b004
--- /dev/null
+++ b/test/unit/org/apache/cassandra/net/proxy/InboundProxyHandler.java
@@ -0,0 +1,234 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.cassandra.net.proxy;
+
+import java.util.ArrayDeque;
+import java.util.concurrent.TimeUnit;
+import java.util.function.Function;
+
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.channel.ChannelInboundHandlerAdapter;
+import io.netty.util.concurrent.EventExecutor;
+import io.netty.util.concurrent.ScheduledFuture;
+
+public class InboundProxyHandler extends ChannelInboundHandlerAdapter
+{
+ private final ArrayDeque<Forward> forwardQueue;
+ private ScheduledFuture scheduled = null;
+ private final Controller controller;
+ public InboundProxyHandler(Controller controller)
+ {
+ this.controller = controller;
+ this.forwardQueue = new ArrayDeque<>(1024);
+ }
+
+ @Override
+ public void channelActive(ChannelHandlerContext ctx) throws Exception
+ {
+ super.channelActive(ctx);
+ ctx.read();
+ }
+
+ @Override
+ public void channelInactive(ChannelHandlerContext ctx) throws Exception {
+ controller.onDisconnect.run();
+
+ if (scheduled != null)
+ {
+ scheduled.cancel(true);
+ scheduled = null;
+ }
+
+ if (!forwardQueue.isEmpty())
+ forwardQueue.clear();
+
+ super.channelInactive(ctx);
+ }
+
+
+ @Override
+ public void channelRead(ChannelHandlerContext ctx, Object msg)
+ {
+ Forward forward = controller.forwardStrategy.forward(ctx, msg);
+ forwardQueue.offer(forward);
+ maybeScheduleNext(ctx.channel().eventLoop());
+ controller.onRead.run();
+ ctx.channel().read();
+ }
+
+ private void maybeScheduleNext(EventExecutor executor)
+ {
+ if (forwardQueue.isEmpty())
+ {
+ // Ran out of items to process
+ scheduled = null;
+ }
+ else if (scheduled == null)
+ {
+ // Schedule next available or let the last in line schedule it
+ Forward forward = forwardQueue.poll();
+ scheduled = forward.schedule(executor);
+ scheduled.addListener((e) -> {
+ scheduled = null;
+ maybeScheduleNext(executor);
+ });
+ }
+ }
+
+ private static class Forward
+ {
+ final long arrivedAt;
+ final long latency;
+ final Runnable handler;
+
+ private Forward(long arrivedAt, long latency, Runnable handler)
+ {
+ this.arrivedAt = arrivedAt;
+ this.latency = latency;
+ this.handler = handler;
+ }
+
+ ScheduledFuture schedule(EventExecutor executor)
+ {
+ long now = System.currentTimeMillis();
+ long elapsed = now - arrivedAt;
+ long runIn = latency - elapsed;
+
+ if (runIn > 0)
+ return executor.schedule(handler, runIn, TimeUnit.MILLISECONDS);
+ else
+ return executor.schedule(handler, 0, TimeUnit.MILLISECONDS);
+ }
+ }
+
+ private static class ForwardNormally implements ForwardStrategy
+ {
+ static ForwardNormally instance = new ForwardNormally();
+
+ public Forward forward(ChannelHandlerContext ctx, Object msg)
+ {
+ return new Forward(System.currentTimeMillis(),
+ 0,
+ () -> ctx.fireChannelRead(msg));
+ }
+ }
+
+ public interface ForwardStrategy
+ {
+ public Forward forward(ChannelHandlerContext ctx, Object msg);
+ }
+
+ private static class ForwardWithLatency implements ForwardStrategy
+ {
+ private final long latency;
+ private final TimeUnit timeUnit;
+
+ ForwardWithLatency(long latency, TimeUnit timeUnit)
+ {
+ this.latency = latency;
+ this.timeUnit = timeUnit;
+ }
+
+ public Forward forward(ChannelHandlerContext ctx, Object msg)
+ {
+ return new Forward(System.currentTimeMillis(),
+ timeUnit.toMillis(latency),
+ () -> ctx.fireChannelRead(msg));
+ }
+ }
+
+ private static class CloseAfterRead implements ForwardStrategy
+ {
+ private final Runnable afterClose;
+
+ CloseAfterRead(Runnable afterClose)
+ {
+ this.afterClose = afterClose;
+ }
+
+ public Forward forward(ChannelHandlerContext ctx, Object msg)
+ {
+ return new Forward(System.currentTimeMillis(),
+ 0,
+ () -> {
+ ctx.channel().close().syncUninterruptibly();
+ afterClose.run();
+ });
+ }
+ }
+
+ private static class TransformPayload<T> implements ForwardStrategy
+ {
+ private final Function<T, T> fn;
+
+ TransformPayload(Function<T, T> fn)
+ {
+ this.fn = fn;
+ }
+
+ public Forward forward(ChannelHandlerContext ctx, Object msg)
+ {
+ return new Forward(System.currentTimeMillis(),
+ 0,
+ () -> ctx.fireChannelRead(fn.apply((T) msg)));
+ }
+ }
+
+ public static class Controller
+ {
+ private volatile InboundProxyHandler.ForwardStrategy forwardStrategy;
+ private volatile Runnable onRead = () -> {};
+ private volatile Runnable onDisconnect = () -> {};
+
+ public Controller()
+ {
+ this.forwardStrategy = ForwardNormally.instance;
+ }
+ public void onRead(Runnable onRead)
+ {
+ this.onRead = onRead;
+ }
+
+ public void onDisconnect(Runnable onDisconnect)
+ {
+ this.onDisconnect = onDisconnect;
+ }
+
+ public void reset()
+ {
+ this.forwardStrategy = ForwardNormally.instance;
+ }
+
+ public void withLatency(long latency, TimeUnit timeUnit)
+ {
+ this.forwardStrategy = new ForwardWithLatency(latency, timeUnit);
+ }
+
+ public void withCloseAfterRead(Runnable afterClose)
+ {
+ this.forwardStrategy = new CloseAfterRead(afterClose);
+ }
+
+ public <T> void withPayloadTransform(Function<T, T> fn)
+ {
+ this.forwardStrategy = new TransformPayload<>(fn);
+ }
+ }
+
+}
diff --git a/test/unit/org/apache/cassandra/net/proxy/ProxyHandlerTest.java b/test/unit/org/apache/cassandra/net/proxy/ProxyHandlerTest.java
new file mode 100644
index 0000000..d070f56
--- /dev/null
+++ b/test/unit/org/apache/cassandra/net/proxy/ProxyHandlerTest.java
@@ -0,0 +1,222 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.cassandra.net.proxy;
+
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.concurrent.atomic.AtomicLong;
+import java.util.function.Consumer;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+import io.netty.bootstrap.Bootstrap;
+import io.netty.bootstrap.ServerBootstrap;
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+import io.netty.channel.Channel;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.channel.ChannelInboundHandlerAdapter;
+import io.netty.channel.ChannelInitializer;
+import io.netty.channel.ChannelOption;
+import io.netty.channel.EventLoopGroup;
+import io.netty.channel.local.LocalAddress;
+import io.netty.channel.local.LocalChannel;
+import io.netty.channel.local.LocalServerChannel;
+import io.netty.channel.nio.NioEventLoopGroup;
+import io.netty.handler.logging.LogLevel;
+import io.netty.handler.logging.LoggingHandler;
+
+/**
+ * Tests {@link InboundProxyHandler} on its own, using Netty's in-VM local transport.
+ */
+public class ProxyHandlerTest
+{
+    private static final Object PAYLOAD = new Object();
+
+    @Test
+    public void testLatency() throws Throwable
+    {
+        test((proxyHandler, testHandler, channel) -> {
+            int count = 1;
+            long latency = 100;
+            CountDownLatch latch = new CountDownLatch(count);
+            long start = System.currentTimeMillis();
+            testHandler.onRead = new Consumer<Object>()
+            {
+                int last = -1;
+
+                public void accept(Object o)
+                {
+                    // Make sure that order is preserved
+                    Assert.assertEquals(last + 1, o);
+                    last = (int) o;
+
+                    // Millisecond timestamps: a measured delay of exactly `latency` is valid, so use >=
+                    long elapsed = System.currentTimeMillis() - start;
+                    Assert.assertTrue("Latency was:" + elapsed, elapsed >= latency);
+                    latch.countDown();
+                }
+            };
+
+            proxyHandler.withLatency(latency, TimeUnit.MILLISECONDS);
+
+            // The local transport hands objects over as-is, so the boxed index itself is the message
+            for (int i = 0; i < count; i++)
+                channel.writeAndFlush(i);
+
+            Assert.assertTrue(latch.await(10, TimeUnit.SECONDS));
+        });
+    }
+
+    @Test
+    public void testNormalDelivery() throws Throwable
+    {
+        test((proxyHandler, testHandler, channel) -> {
+            int count = 10;
+            CountDownLatch latch = new CountDownLatch(count);
+            testHandler.onRead = (o) -> latch.countDown();
+
+            for (int i = 0; i < count; i++)
+                channel.writeAndFlush(PAYLOAD);
+            Assert.assertTrue(latch.await(10, TimeUnit.SECONDS));
+        });
+    }
+
+    @Test
+    public void testLatencyForMany() throws Throwable
+    {
+        // A message that remembers its index, when it was sent, and the latency the proxy applies to it
+        class Event
+        {
+            private final long latency;
+            private final long start;
+            private final int idx;
+
+            Event(long latency, int idx)
+            {
+                this.latency = latency;
+                this.start = System.currentTimeMillis();
+                this.idx = idx;
+            }
+        }
+
+        test((proxyHandler, testHandler, channel) -> {
+            int count = 150;
+            CountDownLatch latch = new CountDownLatch(count);
+            AtomicInteger counter = new AtomicInteger();
+            testHandler.onRead = new Consumer<Object>()
+            {
+                int lastSeen = -1;
+
+                public void accept(Object o)
+                {
+                    Event e = (Event) o;
+                    // Messages must come out in send order despite their differing latencies
+                    Assert.assertEquals(lastSeen + 1, e.idx);
+                    lastSeen = e.idx;
+                    long elapsed = System.currentTimeMillis() - e.start;
+                    Assert.assertTrue("Latency was:" + elapsed, elapsed >= e.latency);
+                    counter.incrementAndGet();
+                    latch.countDown();
+                }
+            };
+
+            int idx = 0;
+            for (int i = 0; i < count / 3; i++)
+            {
+                for (long latency : new long[]{ 100, 200, 0 })
+                {
+                    proxyHandler.withLatency(latency, TimeUnit.MILLISECONDS);
+                    // Wait until the proxy has read this message before switching latency for the next one
+                    CountDownLatch read = new CountDownLatch(1);
+                    proxyHandler.onRead(read::countDown);
+                    channel.writeAndFlush(new Event(latency, idx++));
+                    Assert.assertTrue(read.await(10, TimeUnit.SECONDS));
+                }
+            }
+
+            Assert.assertTrue(latch.await(10, TimeUnit.SECONDS));
+            Assert.assertEquals(count, counter.get());
+        });
+    }
+
+    /** Body of a proxy test: receives the proxy's controller, the terminal test handler and the client channel. */
+    @FunctionalInterface
+    private interface DoTest
+    {
+        void doTest(InboundProxyHandler.Controller proxy, TestHandler testHandler, Channel channel) throws Throwable;
+    }
+
+    /**
+     * Starts a local server whose child pipeline is {@code InboundProxyHandler -> TestHandler} (with AUTO_READ
+     * disabled), connects a client to it, and runs {@code test}. Channels and event loop groups are released even
+     * when the test fails.
+     */
+    public void test(DoTest test) throws Throwable
+    {
+        EventLoopGroup serverGroup = new NioEventLoopGroup(1);
+        EventLoopGroup clientGroup = new NioEventLoopGroup(1);
+        try
+        {
+            InboundProxyHandler.Controller controller = new InboundProxyHandler.Controller();
+            InboundProxyHandler proxyHandler = new InboundProxyHandler(controller);
+            TestHandler testHandler = new TestHandler();
+
+            ServerBootstrap sb = new ServerBootstrap();
+            sb.group(serverGroup)
+              .channel(LocalServerChannel.class)
+              .childHandler(new ChannelInitializer<LocalChannel>()
+              {
+                  @Override
+                  public void initChannel(LocalChannel ch)
+                  {
+                      ch.pipeline()
+                        .addLast(proxyHandler)
+                        .addLast(testHandler);
+                  }
+              })
+              .childOption(ChannelOption.AUTO_READ, false);
+
+            Bootstrap cb = new Bootstrap();
+            cb.group(clientGroup)
+              .channel(LocalChannel.class)
+              .handler(new ChannelInitializer<LocalChannel>()
+              {
+                  @Override
+                  public void initChannel(LocalChannel ch) throws Exception
+                  {
+                      ch.pipeline()
+                        .addLast(new LoggingHandler(LogLevel.TRACE));
+                  }
+              });
+
+            final LocalAddress addr = new LocalAddress("test");
+
+            Channel serverChannel = sb.bind(addr).sync().channel();
+            try
+            {
+                Channel clientChannel = cb.connect(addr).sync().channel();
+                try
+                {
+                    test.doTest(controller, testHandler, clientChannel);
+                }
+                finally
+                {
+                    clientChannel.close();
+                }
+            }
+            finally
+            {
+                serverChannel.close();
+            }
+        }
+        finally
+        {
+            serverGroup.shutdownGracefully();
+            clientGroup.shutdownGracefully();
+        }
+    }
+
+    /** Terminal handler that hands every message it receives to the {@link #onRead} callback. */
+    public static class TestHandler extends ChannelInboundHandlerAdapter
+    {
+        // Replaced by tests on the test thread, invoked on the server event loop
+        private volatile Consumer<Object> onRead = (o) -> {};
+
+        @Override
+        public void channelRead(ChannelHandlerContext ctx, Object msg)
+        {
+            onRead.accept(msg);
+        }
+    }
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@cassandra.apache.org
For additional commands, e-mail: commits-help@cassandra.apache.org