You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@cassandra.apache.org by al...@apache.org on 2019/06/12 15:25:46 UTC

[cassandra] branch trunk updated (9cd1545 -> 532c033)

This is an automated email from the ASF dual-hosted git repository.

aleksey pushed a change to branch trunk
in repository https://gitbox.apache.org/repos/asf/cassandra.git.


    omit 9cd1545  Fix up test after the most recent renames
    omit aed15bf  Introduce a proxy test handler, extra unit tests for connection closure and message expirations
     new 532c033  Introduce a proxy test handler, extra unit tests for connection closure and message expirations

This update added new revisions after undoing existing revisions.
That is to say, some revisions that were in the old version of the
branch are not in the new version.  This situation occurs
when a user --force pushes a change and generates a repository
containing something like this:

 * -- * -- B -- O -- O -- O   (9cd1545)
            \
             N -- N -- N   refs/heads/trunk (532c033)

You should already have received notification emails for all of the O
revisions, and so the following emails describe only the N revisions
from the common base, B.

Any revisions marked "omit" are not gone; other references still
refer to them.  Any revisions marked "discard" are gone forever.

The 1 revision listed above as "new" is entirely new to this
repository and will be described in a separate email.  The revisions
listed as "add" were already present in the repository and have only
been added to this reference.


Summary of changes:


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@cassandra.apache.org
For additional commands, e-mail: commits-help@cassandra.apache.org


[cassandra] 01/01: Introduce a proxy test handler, extra unit tests for connection closure and message expirations

Posted by al...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

aleksey pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/cassandra.git

commit 532c033449dc5b79dd20782e91a243469014d18c
Author: Alex Petrov <ol...@gmail.com>
AuthorDate: Wed Jun 12 15:02:37 2019 +0100

    Introduce a proxy test handler, extra unit tests for connection closure
    and message expirations
    
    patch by Alex Petrov; reviewed by Aleksey Yeschenko and Benedict Elliott
    Smith for CASSANDRA-15066
---
 .../cassandra/net/ProxyHandlerConnectionsTest.java | 405 +++++++++++++++++++++
 .../cassandra/net/proxy/InboundProxyHandler.java   | 234 ++++++++++++
 .../cassandra/net/proxy/ProxyHandlerTest.java      | 222 +++++++++++
 3 files changed, 861 insertions(+)

diff --git a/test/unit/org/apache/cassandra/net/ProxyHandlerConnectionsTest.java b/test/unit/org/apache/cassandra/net/ProxyHandlerConnectionsTest.java
new file mode 100644
index 0000000..270a910
--- /dev/null
+++ b/test/unit/org/apache/cassandra/net/ProxyHandlerConnectionsTest.java
@@ -0,0 +1,405 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.cassandra.net;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.function.Supplier;
+import java.util.function.ToLongFunction;
+
+import com.google.common.util.concurrent.Uninterruptibles;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+import io.netty.buffer.ByteBuf;
+import org.apache.cassandra.config.DatabaseDescriptor;
+import org.apache.cassandra.io.IVersionedAsymmetricSerializer;
+import org.apache.cassandra.io.IVersionedSerializer;
+import org.apache.cassandra.io.util.DataInputPlus;
+import org.apache.cassandra.io.util.DataOutputPlus;
+import org.apache.cassandra.locator.InetAddressAndPort;
+import org.apache.cassandra.net.proxy.InboundProxyHandler;
+import org.apache.cassandra.utils.FBUtilities;
+import org.apache.cassandra.utils.Pair;
+
+import static java.util.concurrent.TimeUnit.MILLISECONDS;
+import static java.util.concurrent.TimeUnit.SECONDS;
+import static org.apache.cassandra.net.ConnectionTest.SETTINGS;
+import static org.apache.cassandra.net.OutboundConnectionSettings.Framing.CRC;
+import static org.apache.cassandra.utils.MonotonicClock.approxTime;
+
+public class ProxyHandlerConnectionsTest
+{
+    private static final SocketFactory factory = new SocketFactory();
+
+    private final Map<Verb, Supplier<? extends IVersionedAsymmetricSerializer<?, ?>>> serializers = new HashMap<>();
+    private final Map<Verb, Supplier<? extends IVerbHandler<?>>> handlers = new HashMap<>();
+    private final Map<Verb, ToLongFunction<TimeUnit>> timeouts = new HashMap<>();
+
+    /**
+     * Installs {@code supplier} as the serializer of {@code verb} and records the value returned by
+     * {@code verb.unsafeSetSerializer} so that {@link #cleanup()} can hand it back to the verb later.
+     * Only the first value recorded for a given verb is kept.
+     */
+    private void unsafeSetSerializer(Verb verb, Supplier<? extends IVersionedAsymmetricSerializer<?, ?>> supplier) throws Throwable
+    {
+        Supplier<? extends IVersionedAsymmetricSerializer<?, ?>> returned = verb.unsafeSetSerializer(supplier);
+        serializers.putIfAbsent(verb, returned);
+    }
+
+    /**
+     * Installs {@code supplier} as the handler of {@code verb}, remembering whatever
+     * {@code verb.unsafeSetHandler} returned (first call per verb only) so that {@link #cleanup()}
+     * can pass it back to the verb once the test is done.
+     */
+    protected void unsafeSetHandler(Verb verb, Supplier<? extends IVerbHandler<?>> supplier) throws Throwable
+    {
+        Supplier<? extends IVerbHandler<?>> returned = verb.unsafeSetHandler(supplier);
+        handlers.putIfAbsent(verb, returned);
+    }
+
+    /**
+     * Installs {@code expiration} as the expiration function of {@code verb}, remembering whatever
+     * {@code verb.unsafeSetExpiration} returned (first call per verb only) so that {@link #cleanup()}
+     * can pass it back to the verb once the test is done.
+     */
+    private void unsafeSetExpiration(Verb verb, ToLongFunction<TimeUnit> expiration) throws Throwable
+    {
+        ToLongFunction<TimeUnit> returned = verb.unsafeSetExpiration(expiration);
+        timeouts.putIfAbsent(verb, returned);
+    }
+
+    // Runs once before the tests of this class: performs DatabaseDescriptor's daemon initialization.
+    @BeforeClass
+    public static void startup()
+    {
+        DatabaseDescriptor.daemonInitialization();
+    }
+
+    /**
+     * Hands every value recorded by the {@code unsafeSet*} helpers back to its verb and empties the three
+     * bookkeeping maps. Runs after each test and is also called directly by the test drivers.
+     */
+    @After
+    public void cleanup() throws Throwable
+    {
+        for (Map.Entry<Verb, Supplier<? extends IVersionedAsymmetricSerializer<?, ?>>> saved : serializers.entrySet())
+        {
+            Verb verb = saved.getKey();
+            verb.unsafeSetSerializer(saved.getValue());
+        }
+        serializers.clear();
+
+        for (Map.Entry<Verb, Supplier<? extends IVerbHandler<?>>> saved : handlers.entrySet())
+        {
+            Verb verb = saved.getKey();
+            verb.unsafeSetHandler(saved.getValue());
+        }
+        handlers.clear();
+
+        for (Map.Entry<Verb, ToLongFunction<TimeUnit>> saved : timeouts.entrySet())
+        {
+            Verb verb = saved.getKey();
+            verb.unsafeSetExpiration(saved.getValue());
+        }
+        timeouts.clear();
+    }
+
+    /**
+     * Delivers one message through the untouched proxy, then sets a 50ms expiration for the verb and a
+     * 100ms proxy latency and checks that all ten subsequently enqueued messages are counted as expired
+     * by the inbound handlers.
+     */
+    @Test
+    public void testExpireInbound() throws Throwable
+    {
+        DatabaseDescriptor.setCrossNodeTimeout(true);
+        testOneManual((settings, inbound, outbound, endpoint, handler) -> {
+            unsafeSetSerializer(Verb._TEST_1, FakePayloadSerializer::new);
+
+            // The first message must arrive before any latency is introduced
+            CountDownLatch delivered = new CountDownLatch(1);
+            unsafeSetHandler(Verb._TEST_1, () -> v -> delivered.countDown());
+            outbound.enqueue(Message.out(Verb._TEST_1, 1L));
+            delivered.await(10, SECONDS);
+            Assert.assertEquals(0, delivered.getCount());
+
+            // Slow things down
+            unsafeSetExpiration(Verb._TEST_1, unit -> unit.convert(50, MILLISECONDS));
+            handler.withLatency(100, MILLISECONDS);
+
+            unsafeSetHandler(Verb._TEST_1, () -> v -> {
+                throw new RuntimeException("Should have not been triggered " + v);
+            });
+            final int toExpire = 10;
+            int sent = 0;
+            while (sent < toExpire)
+            {
+                outbound.enqueue(Message.out(Verb._TEST_1, 1L));
+                sent++;
+            }
+
+            InboundMessageHandlers inboundHandlers = MessagingService.instance().getInbound(endpoint);
+            waitForCondition(() -> inboundHandlers.expiredCount() == toExpire);
+            Assert.assertEquals(toExpire, inboundHandlers.expiredCount());
+        });
+    }
+
+    /**
+     * Three phases of ten messages each: delivered normally, expired (50ms expiration against a 100ms proxy
+     * latency), and delivered again once the latency is lowered to 2ms.
+     */
+    @Test
+    public void testExpireSome() throws Throwable
+    {
+        DatabaseDescriptor.setCrossNodeTimeout(true);
+        testOneManual((settings, inbound, outbound, endpoint, handler) -> {
+            unsafeSetSerializer(Verb._TEST_1, FakePayloadSerializer::new);
+            connect(outbound);
+
+            AtomicInteger received = new AtomicInteger();
+            unsafeSetHandler(Verb._TEST_1, () -> v -> received.incrementAndGet());
+
+            final int batch = 10;
+
+            // Phase 1: no latency, everything is handled
+            for (int sent = 0; sent < batch; sent++)
+                outbound.enqueue(Message.out(Verb._TEST_1, 1L));
+            waitForCondition(() -> received.get() == 10);
+
+            // Phase 2: latency exceeds the expiration, everything expires
+            unsafeSetExpiration(Verb._TEST_1, unit -> unit.convert(50, MILLISECONDS));
+            handler.withLatency(100, MILLISECONDS);
+
+            InboundMessageHandlers inboundHandlers = MessagingService.instance().getInbound(endpoint);
+            for (int sent = 0; sent < batch; sent++)
+                outbound.enqueue(Message.out(Verb._TEST_1, 1L));
+            waitForCondition(() -> inboundHandlers.expiredCount() == 10);
+
+            // Phase 3: latency is back below the expiration, everything is handled again
+            handler.withLatency(2, MILLISECONDS);
+
+            for (int sent = 0; sent < batch; sent++)
+                outbound.enqueue(Message.out(Verb._TEST_1, 1L));
+            waitForCondition(() -> received.get() == 20);
+        });
+    }
+
+    /**
+     * Sends a batch of 20 messages through a proxy that adds 100ms of latency. Every other message is
+     * prepared so that it should expire on the way, and the test waits until exactly 10 messages were
+     * counted as expired and 10 were handled. Runs once for each entry of {@code SETTINGS}.
+     */
+    @Test
+    public void testExpireSomeFromBatch() throws Throwable
+    {
+        DatabaseDescriptor.setCrossNodeTimeout(true);
+        testManual((settings, inbound, outbound, endpoint, handler) -> {
+            unsafeSetSerializer(Verb._TEST_1, FakePayloadSerializer::new);
+            connect(outbound);
+
+            Message msg = Message.out(Verb._TEST_1, 1L);
+            int messageSize = msg.serializedSize(MessagingService.current_version);
+            DatabaseDescriptor.setInternodeMaxMessageSizeInBytes(messageSize * 40);
+
+            AtomicInteger counter = new AtomicInteger();
+            unsafeSetHandler(Verb._TEST_1, () -> v -> {
+                counter.incrementAndGet();
+            });
+
+            unsafeSetExpiration(Verb._TEST_1, unit -> unit.convert(200, MILLISECONDS));
+            handler.withLatency(100, MILLISECONDS);
+
+            int expireMessages = 20;
+            long nanoTime = approxTime.now();
+            // Hold delivery back until the whole batch has been enqueued
+            CountDownLatch enqueueDone = new CountDownLatch(1);
+            outbound.unsafeRunOnDelivery(() -> Uninterruptibles.awaitUninterruptibly(enqueueDone, 10, SECONDS));
+            for (int i = 0; i < expireMessages; i++)
+            {
+                boolean expire = i % 2 == 0;
+                Message.Builder builder = Message.builder(Verb._TEST_1, 1L);
+
+                if (settings.right.acceptVersions == ConnectionTest.legacy)
+                {
+                    // Backdate only the messages that are meant to expire: 150ms out of the 200ms verb
+                    // expiration leaves them 50 milliseconds to leave outbound path, which is less than the
+                    // 100ms proxy latency. The remaining messages keep their full 200ms.
+                    // (Previously the two cases were swapped, contradicting the non-legacy branch below.)
+                    builder.withCreatedAt(nanoTime - (expire ? MILLISECONDS.toNanos(150) : 0));
+                }
+                else
+                {
+                    // Give messages 50 milliseconds to leave outbound path
+                    builder.withCreatedAt(nanoTime)
+                           .withExpiresAt(nanoTime + (expire ? MILLISECONDS.toNanos(50) : MILLISECONDS.toNanos(1000)));
+                }
+                outbound.enqueue(builder.build());
+            }
+            enqueueDone.countDown();
+
+            InboundMessageHandlers handlers = MessagingService.instance().getInbound(endpoint);
+            waitForCondition(() -> handlers.expiredCount() == 10 && counter.get() == 10,
+                             () -> String.format("Expired: %d, Arrived: %d", handlers.expiredCount(), counter.get()));
+        });
+    }
+
+    /**
+     * Makes the proxy close the channel upon the next read, waits for the outbound connection to notice
+     * the disconnect, and verifies that it can connect again while the counting handler saw no message.
+     * Runs once for each entry of {@code SETTINGS}.
+     */
+    @Test
+    public void suddenDisconnect() throws Throwable
+    {
+        testManual((settings, inbound, outbound, endpoint, handler) -> {
+            // Go back to normal forwarding as soon as the channel goes down
+            handler.onDisconnect(handler::reset);
+
+            unsafeSetSerializer(Verb._TEST_1, FakePayloadSerializer::new);
+            connect(outbound);
+
+            CountDownLatch closeLatch = new CountDownLatch(1);
+            handler.withCloseAfterRead(closeLatch::countDown);
+            AtomicInteger handled = new AtomicInteger();
+            unsafeSetHandler(Verb._TEST_1, () -> v -> handled.incrementAndGet());
+
+            outbound.enqueue(Message.out(Verb._TEST_1, 1L));
+            waitForCondition(() -> !outbound.isConnected());
+
+            connect(outbound);
+            Assert.assertTrue(outbound.isConnected());
+            Assert.assertEquals(0, handled.get());
+        });
+    }
+
+    @Test
+    public void testCorruptionOnHandshake() throws Throwable
+    {
+        testManual((settings, inbound, outbound, endpoint, handler) -> {
+            unsafeSetSerializer(Verb._TEST_1, FakePayloadSerializer::new);
+            // Invalid CRC
+            handler.withPayloadTransform(msg -> {
+                ByteBuf bb = (ByteBuf) msg;
+                bb.setByte(bb.readableBytes() / 2, 0xffff);
+                return msg;
+            });
+            tryConnect(outbound, 1, SECONDS, false);
+            Assert.assertTrue(!outbound.isConnected());
+
+            // Invalid protocol magic
+            handler.withPayloadTransform(msg -> {
+                ByteBuf bb = (ByteBuf) msg;
+                bb.setByte(0, 0xffff);
+                return msg;
+            });
+            tryConnect(outbound, 1, SECONDS, false);
+            Assert.assertTrue(!outbound.isConnected());
+            if (settings.right.framing == CRC)
+            {
+                Assert.assertEquals(2, outbound.connectionAttempts());
+                Assert.assertEquals(0, outbound.successfulConnections());
+            }
+        });
+    }
+
+    /**
+     * Waits up to 10 seconds for {@code cond} to become true.
+     *
+     * @param cond condition to poll
+     * @throws AssertionError if the condition is still false once the 10 seconds have passed
+     */
+    private static void waitForCondition(Supplier<Boolean> cond) throws Throwable
+    {
+        waitForCondition(cond, () -> "Condition was not met within 10 seconds");
+    }
+
+    /**
+     * Waits up to 10 seconds for {@code cond} to become true, polling it on the calling thread with a
+     * short sleep in between. Polling in place (rather than busy-spinning on a pool thread) means nothing
+     * keeps running after a timeout.
+     *
+     * @param cond condition to poll
+     * @param s    supplies the failure message; only invoked when the wait times out
+     * @throws AssertionError if the condition is still false once the 10 seconds have passed
+     */
+    private static void waitForCondition(Supplier<Boolean> cond, Supplier<String> s) throws Throwable
+    {
+        long deadline = System.nanoTime() + SECONDS.toNanos(10);
+        while (!cond.get())
+        {
+            // Subtraction keeps the comparison correct even if nanoTime wraps around
+            if (System.nanoTime() - deadline >= 0)
+                throw new AssertionError(s.get());
+            MILLISECONDS.sleep(1);
+        }
+    }
+
+    /**
+     * Test serializer that writes a single {@code long} value {@code size} times and, when reading it
+     * back, checks that every repetition carries the same value.
+     */
+    private static class FakePayloadSerializer implements IVersionedSerializer<Long>
+    {
+        private final int size;
+
+        private FakePayloadSerializer()
+        {
+            this(1);
+        }
+
+        // Takes long and repeats it size times
+        private FakePayloadSerializer(int size)
+        {
+            this.size = size;
+        }
+
+        public void serialize(Long i, DataOutputPlus out, int version) throws IOException
+        {
+            int written = 0;
+            while (written < size)
+            {
+                out.writeLong(i);
+                written++;
+            }
+        }
+
+        public Long deserialize(DataInputPlus in, int version) throws IOException
+        {
+            final long first = in.readLong();
+            // The remaining size - 1 repetitions must all match the first value
+            for (int remaining = size - 1; remaining > 0; remaining--)
+            {
+                long next = in.readLong();
+                if (next != first)
+                    throw new AssertionError();
+            }
+            return first;
+        }
+
+        public long serializedSize(Long t, int version)
+        {
+            return size * Long.BYTES;
+        }
+    }
+    /**
+     * Body of a test driven by {@code doTestManual}: receives the inbound/outbound settings pair, the opened
+     * inbound sockets, the outbound connection, the endpoint both sides use, and the controller of the
+     * proxy handler installed into the inbound pipeline.
+     */
+    interface ManualSendTest
+    {
+        void accept(Pair<InboundConnectionSettings, OutboundConnectionSettings> settings, InboundSockets inbound, OutboundConnection outbound, InetAddressAndPort endpoint, InboundProxyHandler.Controller handler) throws Throwable;
+    }
+
+    /** Runs {@code test} once for every entry of {@code SETTINGS}, calling {@link #cleanup()} after each run. */
+    private void testManual(ManualSendTest test) throws Throwable
+    {
+        for (ConnectionTest.Settings current : SETTINGS)
+        {
+            doTestManual(current, test);
+            cleanup();
+        }
+    }
+
+    private void testOneManual(ManualSendTest test) throws Throwable
+    {
+        testOneManual(test, 1);
+    }
+
+    /** Runs {@code test} against the {@code SETTINGS} entry at index {@code i}, then calls {@link #cleanup()}. */
+    private void testOneManual(ManualSendTest test, int i) throws Throwable
+    {
+        doTestManual(SETTINGS.get(i), test);
+        cleanup();
+    }
+
+    /**
+     * Sets up an inbound socket and an outbound connection that both use the broadcast address, installs an
+     * {@link InboundProxyHandler} into the inbound pipeline, and runs {@code test} against them. Both sides
+     * are closed and the messaging service's message handlers are cleared afterwards, whether or not the
+     * test body threw.
+     */
+    private void doTestManual(ConnectionTest.Settings settings, ManualSendTest test) throws Throwable
+    {
+        InetAddressAndPort endpoint = FBUtilities.getBroadcastAddressAndPort();
+
+        InboundConnectionSettings inboundSettings = settings.inbound.apply(new InboundConnectionSettings())
+                                                                    .withBindAddress(endpoint)
+                                                                    .withSocketFactory(factory);
+
+        InboundSockets inbound = new InboundSockets(Collections.singletonList(inboundSettings));
+
+        OutboundConnectionSettings outboundSettings = settings.outbound.apply(new OutboundConnectionSettings(endpoint))
+                                                                       .withConnectTo(endpoint)
+                                                                       .withDefaultReserveLimits()
+                                                                       .withSocketFactory(factory);
+
+        ResourceLimits.EndpointAndGlobal reserveCapacityInBytes = new ResourceLimits.EndpointAndGlobal(new ResourceLimits.Concurrent(outboundSettings.applicationSendQueueReserveEndpointCapacityInBytes), outboundSettings.applicationSendQueueReserveGlobalCapacityInBytes);
+        OutboundConnection outbound = new OutboundConnection(settings.type, outboundSettings, reserveCapacityInBytes);
+        try
+        {
+            // A single controller is shared by every proxy handler created for this inbound socket
+            InboundProxyHandler.Controller controller = new InboundProxyHandler.Controller();
+            inbound.open(pipeline -> {
+                InboundProxyHandler handler = new InboundProxyHandler(controller);
+                pipeline.addLast(handler);
+            }).sync();
+            test.accept(Pair.create(inboundSettings, outboundSettings), inbound, outbound, endpoint, controller);
+        }
+        finally
+        {
+            // NOTE(review): outbound.close(false) is called twice, and only the second result is awaited;
+            // presumably the first call merely starts the close before the inbound side goes down — confirm.
+            outbound.close(false);
+            inbound.close().get(30L, SECONDS);
+            outbound.close(false).get(30L, SECONDS);
+            MessagingService.instance().messageHandlers.clear();
+        }
+    }
+
+    private void connect(OutboundConnection outbound) throws Throwable
+    {
+        tryConnect(outbound, 10, SECONDS, true);
+    }
+
+    /**
+     * Enqueues a single {@code _TEST_1} message on {@code outbound} and waits up to the given timeout for
+     * the verb handler installed here to see it. Note that this replaces the current {@code _TEST_1} handler.
+     *
+     * @param throwOnFailure when true, an assertion fails if the message was not handled in time;
+     *                       when false, the method returns silently either way
+     */
+    private void tryConnect(OutboundConnection outbound, long timeout, TimeUnit timeUnit, boolean throwOnFailure) throws Throwable
+    {
+        CountDownLatch delivered = new CountDownLatch(1);
+        unsafeSetHandler(Verb._TEST_1, () -> v -> delivered.countDown());
+        outbound.enqueue(Message.out(Verb._TEST_1, 1L));
+        delivered.await(timeout, timeUnit);
+        if (!throwOnFailure)
+            return;
+        Assert.assertEquals(0, delivered.getCount());
+    }
+}
diff --git a/test/unit/org/apache/cassandra/net/proxy/InboundProxyHandler.java b/test/unit/org/apache/cassandra/net/proxy/InboundProxyHandler.java
new file mode 100644
index 0000000..7e3b004
--- /dev/null
+++ b/test/unit/org/apache/cassandra/net/proxy/InboundProxyHandler.java
@@ -0,0 +1,234 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.cassandra.net.proxy;
+
+import java.util.ArrayDeque;
+import java.util.concurrent.TimeUnit;
+import java.util.function.Function;
+
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.channel.ChannelInboundHandlerAdapter;
+import io.netty.util.concurrent.EventExecutor;
+import io.netty.util.concurrent.ScheduledFuture;
+
+public class InboundProxyHandler extends ChannelInboundHandlerAdapter
+{
+    // Forwards that were read but not yet handed to the executor
+    private final ArrayDeque<Forward> forwardQueue;
+    // Future of the forward currently handed to the executor, or null when there is none
+    private ScheduledFuture scheduled = null;
+    // Source of the forwarding strategy and of the onRead / onDisconnect callbacks
+    private final Controller controller;
+    /** Creates a handler driven by {@code controller}, with an empty forward queue. */
+    public InboundProxyHandler(Controller controller)
+    {
+        this.controller = controller;
+        this.forwardQueue = new ArrayDeque<>(1024);
+    }
+
+    // Lets the default channelActive processing run, then explicitly requests a read on the context.
+    // NOTE(review): presumably needed because the channel is used with auto-read disabled — confirm.
+    @Override
+    public void channelActive(ChannelHandlerContext ctx) throws Exception
+    {
+        super.channelActive(ctx);
+        ctx.read();
+    }
+
+    /**
+     * Runs the controller's disconnect callback, discards every forward that has not run yet, and then
+     * lets the default channelInactive processing continue.
+     */
+    @Override
+    public void channelInactive(ChannelHandlerContext ctx) throws Exception {
+        controller.onDisconnect.run();
+
+        // Empty the queue BEFORE cancelling: the completion listener registered in maybeScheduleNext also
+        // runs on cancellation and polls the queue. If entries were still queued at that point, the next
+        // forward would be scheduled and then lose its reference below, firing on the inactive channel.
+        forwardQueue.clear();
+
+        if (scheduled != null)
+        {
+            scheduled.cancel(true);
+            scheduled = null;
+        }
+
+        super.channelInactive(ctx);
+    }
+
+
+    // Wraps the incoming message into a Forward using the controller's current strategy, queues it,
+    // makes sure something is scheduled, runs the onRead callback and requests the next read.
+    // The message is not passed on here; that is left to the Forward's handler.
+    @Override
+    public void channelRead(ChannelHandlerContext ctx, Object msg)
+    {
+        Forward forward = controller.forwardStrategy.forward(ctx, msg);
+        forwardQueue.offer(forward);
+        maybeScheduleNext(ctx.channel().eventLoop());
+        controller.onRead.run();
+        ctx.channel().read();
+    }
+
+    // Schedules the head of the queue unless a forward is already scheduled. At most one forward is
+    // scheduled at a time; its completion listener clears the slot and schedules the following one,
+    // so forwards run in arrival order.
+    private void maybeScheduleNext(EventExecutor executor)
+    {
+        if (forwardQueue.isEmpty())
+        {
+            // Ran out of items to process
+            scheduled = null;
+        }
+        else if (scheduled == null)
+        {
+            // Schedule next available or let the last in line schedule it
+            Forward forward = forwardQueue.poll();
+            scheduled = forward.schedule(executor);
+            // The listener is invoked once the scheduled future completes
+            scheduled.addListener((e) -> {
+                scheduled = null;
+                maybeScheduleNext(executor);
+            });
+        }
+    }
+
+    /**
+     * A pending delivery: the action to run, the wall-clock time (milliseconds) at which the message
+     * arrived, and the latency in milliseconds that should pass between arrival and the action.
+     */
+    private static class Forward
+    {
+        final long arrivedAt;
+        final long latency;
+        final Runnable handler;
+
+        private Forward(long arrivedAt, long latency, Runnable handler)
+        {
+            this.arrivedAt = arrivedAt;
+            this.latency = latency;
+            this.handler = handler;
+        }
+
+        // Schedules the handler for whatever is left of the latency, or immediately if it already elapsed
+        ScheduledFuture schedule(EventExecutor executor)
+        {
+            long waitedSoFar = System.currentTimeMillis() - arrivedAt;
+            long remaining = latency - waitedSoFar;
+            long delay = Math.max(remaining, 0);
+            return executor.schedule(handler, delay, TimeUnit.MILLISECONDS);
+        }
+    }
+
+    private static class ForwardNormally implements ForwardStrategy
+    {
+        static ForwardNormally instance = new ForwardNormally();
+
+        public Forward forward(ChannelHandlerContext ctx, Object msg)
+        {
+            return new Forward(System.currentTimeMillis(),
+                               0,
+                               () -> ctx.fireChannelRead(msg));
+        }
+    }
+
+    /** Decides what happens to a message read by the proxy by turning it into a {@code Forward}. */
+    public interface ForwardStrategy
+    {
+        public Forward forward(ChannelHandlerContext ctx, Object msg);
+    }
+
+    private static class ForwardWithLatency implements ForwardStrategy
+    {
+        private final long latency;
+        private final TimeUnit timeUnit;
+
+        ForwardWithLatency(long latency, TimeUnit timeUnit)
+        {
+            this.latency = latency;
+            this.timeUnit = timeUnit;
+        }
+
+        public Forward forward(ChannelHandlerContext ctx, Object msg)
+        {
+            return new Forward(System.currentTimeMillis(),
+                               timeUnit.toMillis(latency),
+                               () -> ctx.fireChannelRead(msg));
+        }
+    }
+
+    private static class CloseAfterRead implements ForwardStrategy
+    {
+        private final Runnable afterClose;
+
+        CloseAfterRead(Runnable afterClose)
+        {
+            this.afterClose = afterClose;
+        }
+
+        public Forward forward(ChannelHandlerContext ctx, Object msg)
+        {
+            return  new Forward(System.currentTimeMillis(),
+                                0,
+                                () -> {
+                                    ctx.channel().close().syncUninterruptibly();
+                                    afterClose.run();
+                                });
+        }
+    }
+
+    private static class TransformPayload<T> implements ForwardStrategy
+    {
+        private final Function<T, T> fn;
+
+        TransformPayload(Function<T, T> fn)
+        {
+            this.fn = fn;
+        }
+
+        public Forward forward(ChannelHandlerContext ctx, Object msg)
+        {
+            return new Forward(System.currentTimeMillis(),
+                               0,
+                               () -> ctx.fireChannelRead(fn.apply((T) msg)));
+        }
+    }
+
+    /**
+     * Test-facing switchboard of {@link InboundProxyHandler}: selects how incoming messages are forwarded
+     * and holds the callbacks run on every read and on disconnect. All fields are volatile; a new
+     * controller forwards normally and has no-op callbacks.
+     */
+    public static class Controller
+    {
+        private volatile Runnable onRead = () -> {};
+        private volatile Runnable onDisconnect = () -> {};
+        private volatile InboundProxyHandler.ForwardStrategy forwardStrategy;
+
+        public Controller()
+        {
+            forwardStrategy = ForwardNormally.instance;
+        }
+
+        /** Replaces the callback run after each message read by the proxy. */
+        public void onRead(Runnable onRead)
+        {
+            this.onRead = onRead;
+        }
+
+        /** Replaces the callback run when the proxied channel becomes inactive. */
+        public void onDisconnect(Runnable onDisconnect)
+        {
+            this.onDisconnect = onDisconnect;
+        }
+
+        /** Goes back to forwarding normally; the callbacks are left untouched. */
+        public void reset()
+        {
+            forwardStrategy = ForwardNormally.instance;
+        }
+
+        /** Forwards subsequent messages after the given latency. */
+        public void withLatency(long latency, TimeUnit timeUnit)
+        {
+            forwardStrategy = new ForwardWithLatency(latency, timeUnit);
+        }
+
+        /** Closes the channel on subsequent reads and runs {@code afterClose} once it is closed. */
+        public void withCloseAfterRead(Runnable afterClose)
+        {
+            forwardStrategy = new CloseAfterRead(afterClose);
+        }
+
+        /** Forwards subsequent messages after passing them through {@code fn}. */
+        public <T> void withPayloadTransform(Function<T, T> fn)
+        {
+            forwardStrategy = new TransformPayload<T>(fn);
+        }
+    }
+
+}
diff --git a/test/unit/org/apache/cassandra/net/proxy/ProxyHandlerTest.java b/test/unit/org/apache/cassandra/net/proxy/ProxyHandlerTest.java
new file mode 100644
index 0000000..d070f56
--- /dev/null
+++ b/test/unit/org/apache/cassandra/net/proxy/ProxyHandlerTest.java
@@ -0,0 +1,222 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.cassandra.net.proxy;
+
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.concurrent.atomic.AtomicLong;
+import java.util.function.Consumer;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+import io.netty.bootstrap.Bootstrap;
+import io.netty.bootstrap.ServerBootstrap;
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+import io.netty.channel.Channel;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.channel.ChannelInboundHandlerAdapter;
+import io.netty.channel.ChannelInitializer;
+import io.netty.channel.ChannelOption;
+import io.netty.channel.EventLoopGroup;
+import io.netty.channel.local.LocalAddress;
+import io.netty.channel.local.LocalChannel;
+import io.netty.channel.local.LocalServerChannel;
+import io.netty.channel.nio.NioEventLoopGroup;
+import io.netty.handler.logging.LogLevel;
+import io.netty.handler.logging.LoggingHandler;
+
+public class ProxyHandlerTest
+{
+    private final Object PAYLOAD = new Object();
+
+    /**
+     * Sends messages through a proxy configured with 100ms of latency and checks that they reach the test
+     * handler in order and not earlier than the configured latency after the test started sending.
+     */
+    @Test
+    public void testLatency() throws Throwable
+    {
+        test((proxyHandler, testHandler, channel) -> {
+            int count = 1;
+            long latency = 100;
+            CountDownLatch latch = new CountDownLatch(count);
+            long start = System.currentTimeMillis();
+            testHandler.onRead = new Consumer<Object>()
+            {
+                int last = -1;
+                public void accept(Object o)
+                {
+                    // Make sure that order is preserved
+                    Assert.assertEquals(last + 1, o);
+                    last = (int) o;
+
+                    // Arriving exactly at the latency boundary is fine, same as in testLatencyForMany
+                    long elapsed = System.currentTimeMillis() - start;
+                    Assert.assertTrue("Latency was:" + elapsed, elapsed >= latency);
+                    latch.countDown();
+                }
+            };
+
+            proxyHandler.withLatency(latency, TimeUnit.MILLISECONDS);
+
+            // The payload is the message index itself; the handler above relies on that for the order check
+            for (int i = 0; i < count; i++)
+                channel.writeAndFlush(i);
+
+            // A failed assertion in the handler prevents the count-down, so it surfaces here as a timeout
+            Assert.assertTrue(latch.await(10, TimeUnit.SECONDS));
+        });
+    }
+
+    /** Sends ten messages through an unconfigured proxy and checks that all of them reach the test handler. */
+    @Test
+    public void testNormalDelivery() throws Throwable
+    {
+        test((proxyHandler, testHandler, channel) -> {
+            int count = 10;
+            CountDownLatch latch = new CountDownLatch(count);
+            testHandler.onRead = (o) -> latch.countDown();
+
+            for (int i = 0; i < count; i++)
+                channel.writeAndFlush(PAYLOAD);
+            Assert.assertTrue(latch.await(10, TimeUnit.SECONDS));
+        });
+    }
+
+    /**
+     * Sends 150 events while cycling the proxy latency through 100ms, 200ms and 0ms, and checks that the
+     * events reach the test handler in order, none earlier than the latency that was configured when it
+     * was sent.
+     */
+    @Test
+    public void testLatencyForMany() throws Throwable
+    {
+        // Payload carrying its own send time, expected latency and position in the sequence
+        class Event {
+            private final long latency;
+            private final long start;
+            private final int idx;
+
+            Event(long latency, int idx)
+            {
+                this.latency = latency;
+                this.start = System.currentTimeMillis();
+                this.idx = idx;
+            }
+        }
+
+        test((proxyHandler, testHandler, channel) -> {
+            int count = 150;
+            CountDownLatch latch = new CountDownLatch(count);
+            AtomicInteger counter = new AtomicInteger();
+            testHandler.onRead = new Consumer<Object>()
+            {
+                int lastSeen = -1;
+                public void accept(Object o)
+                {
+                    Event e = (Event) o;
+                    Assert.assertEquals(lastSeen + 1, e.idx);
+                    lastSeen = e.idx;
+                    long elapsed = System.currentTimeMillis() - e.start;
+                    Assert.assertTrue(elapsed >= e.latency);
+                    counter.incrementAndGet();
+                    latch.countDown();
+                }
+            };
+
+            int idx = 0;
+            for (int i = 0; i < count / 3; i++)
+            {
+                for (long latency : new long[]{ 100, 200, 0 })
+                {
+                    proxyHandler.withLatency(latency, TimeUnit.MILLISECONDS);
+                    // Wait until the proxy has read the event before changing the latency for the next
+                    // one, since the strategy is picked up at read time
+                    CountDownLatch read = new CountDownLatch(1);
+                    proxyHandler.onRead(read::countDown);
+                    channel.writeAndFlush(new Event(latency, idx++));
+                    Assert.assertTrue(read.await(10, TimeUnit.SECONDS));
+                }
+            }
+
+            Assert.assertTrue(latch.await(10, TimeUnit.SECONDS));
+            // expected value first, so that a failure message reads correctly
+            Assert.assertEquals(count, counter.get());
+        });
+    }
+
+    /**
+     * Body of a test run by {@code test}: receives the controller of the proxy handler, the handler placed
+     * after the proxy in the server pipeline, and the connected client channel to write to.
+     */
+    private interface DoTest
+    {
+        public void doTest(InboundProxyHandler.Controller proxy, TestHandler testHandler, Channel channel) throws Throwable;
+    }
+
+
+    /**
+     * Starts a local server whose child pipeline is an {@link InboundProxyHandler} followed by a
+     * {@link TestHandler}, connects a local client to it and runs {@code test} against the client channel.
+     * <p>
+     * Channels and event loop groups are released in a finally block, and the server close is awaited, so
+     * that a failing test body cannot leave the local address bound for the tests that follow.
+     *
+     * @param test the test body
+     * @throws Throwable whatever the test body, bind or connect throws
+     */
+    public void test(DoTest test) throws Throwable
+    {
+        EventLoopGroup serverGroup = new NioEventLoopGroup(1);
+        EventLoopGroup clientGroup = new NioEventLoopGroup(1);
+        Channel serverChannel = null;
+        Channel clientChannel = null;
+
+        try
+        {
+            InboundProxyHandler.Controller controller = new InboundProxyHandler.Controller();
+            InboundProxyHandler proxyHandler = new InboundProxyHandler(controller);
+            TestHandler testHandler = new TestHandler();
+
+            ServerBootstrap sb = new ServerBootstrap();
+            sb.group(serverGroup)
+              .channel(LocalServerChannel.class)
+              .childHandler(new ChannelInitializer<LocalChannel>() {
+                  @Override
+                  public void initChannel(LocalChannel ch)
+                  {
+                      ch.pipeline()
+                        .addLast(proxyHandler)
+                        .addLast(testHandler);
+                  }
+              })
+              // The proxy handler requests reads explicitly
+              .childOption(ChannelOption.AUTO_READ, false);
+
+            Bootstrap cb = new Bootstrap();
+            cb.group(clientGroup)
+              .channel(LocalChannel.class)
+              .handler(new ChannelInitializer<LocalChannel>() {
+                  @Override
+                  public void initChannel(LocalChannel ch) throws Exception {
+                      ch.pipeline()
+                        .addLast(new LoggingHandler(LogLevel.TRACE));
+                  }
+              });
+
+            final LocalAddress addr = new LocalAddress("test");
+
+            serverChannel = sb.bind(addr).sync().channel();
+            clientChannel = cb.connect(addr).sync().channel();
+            test.doTest(controller, testHandler, clientChannel);
+        }
+        finally
+        {
+            if (clientChannel != null)
+                clientChannel.close();
+            // Await without rethrowing, so that a close failure cannot mask the test's own exception
+            if (serverChannel != null)
+                serverChannel.close().awaitUninterruptibly();
+            serverGroup.shutdownGracefully();
+            clientGroup.shutdownGracefully();
+        }
+    }
+
+
+    /**
+     * Terminal handler of the test pipeline: hands every message it reads to the {@code onRead} consumer,
+     * which tests replace to make their assertions. The message is not passed any further.
+     */
+    public static class TestHandler extends ChannelInboundHandlerAdapter
+    {
+        // No-op until a test installs its own consumer
+        private Consumer<Object> onRead = (o) -> {};
+        @Override
+        public void channelRead(ChannelHandlerContext ctx, Object msg)
+        {
+            onRead.accept(msg);
+        }
+    }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@cassandra.apache.org
For additional commands, e-mail: commits-help@cassandra.apache.org