You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@cassandra.apache.org by al...@apache.org on 2019/06/12 14:54:53 UTC

[cassandra] 02/02: Introduce a proxy test handler, extra unit tests for connection closure and message expirations

This is an automated email from the ASF dual-hosted git repository.

aleksey pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/cassandra.git

commit aed15bfb01e753f9bac8b149b382c7a7c8d33183
Author: Alex Petrov <ol...@gmail.com>
AuthorDate: Wed Jun 12 15:02:37 2019 +0100

    Introduce a proxy test handler, extra unit tests for connection closure
    and message expirations
    
    patch by Alex Petrov; reviewed by Aleksey Yeschenko and Benedict Elliott
    Smith for CASSANDRA-15066
---
 .../cassandra/net/ProxyHandlerConnectionsTest.java | 405 +++++++++++++++++++++
 .../cassandra/net/proxy/InboundProxyHandler.java   | 234 ++++++++++++
 .../cassandra/net/proxy/ProxyHandlerTest.java      | 222 +++++++++++
 3 files changed, 861 insertions(+)

diff --git a/test/unit/org/apache/cassandra/net/ProxyHandlerConnectionsTest.java b/test/unit/org/apache/cassandra/net/ProxyHandlerConnectionsTest.java
new file mode 100644
index 0000000..ae41cce
--- /dev/null
+++ b/test/unit/org/apache/cassandra/net/ProxyHandlerConnectionsTest.java
@@ -0,0 +1,405 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.cassandra.net;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.function.Supplier;
+import java.util.function.ToLongFunction;
+
+import com.google.common.util.concurrent.Uninterruptibles;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+import io.netty.buffer.ByteBuf;
+import org.apache.cassandra.config.DatabaseDescriptor;
+import org.apache.cassandra.io.IVersionedAsymmetricSerializer;
+import org.apache.cassandra.io.IVersionedSerializer;
+import org.apache.cassandra.io.util.DataInputPlus;
+import org.apache.cassandra.io.util.DataOutputPlus;
+import org.apache.cassandra.locator.InetAddressAndPort;
+import org.apache.cassandra.net.proxy.InboundProxyHandler;
+import org.apache.cassandra.utils.FBUtilities;
+import org.apache.cassandra.utils.Pair;
+
+import static java.util.concurrent.TimeUnit.MILLISECONDS;
+import static java.util.concurrent.TimeUnit.SECONDS;
+import static org.apache.cassandra.net.ConnectionTest.SETTINGS;
+import static org.apache.cassandra.net.OutboundConnectionSettings.Framing.CRC;
+import static org.apache.cassandra.utils.MonotonicClock.approxTime;
+
+/**
+ * Messaging connection tests that route every inbound channel through an {@link InboundProxyHandler}, so a
+ * test can delay, drop or corrupt what the inbound side reads and then assert on message expiration,
+ * reconnection after a sudden disconnect, and handshake failures.
+ *
+ * Tests override verb serializers, handlers and expirations; {@link #cleanup()} hands back whatever was
+ * recorded for every overridden verb.
+ */
+public class ProxyHandlerConnectionsTest
+{
+    private static final SocketFactory factory = new SocketFactory();
+
+    // What Verb.unsafeSet* returned on the first override of each verb (presumably the previous value),
+    // passed back to the same setters by cleanup()
+    private final Map<Verb, Supplier<? extends IVersionedAsymmetricSerializer<?, ?>>> serializers = new HashMap<>();
+    private final Map<Verb, Supplier<? extends IVerbHandler<?>>> handlers = new HashMap<>();
+    private final Map<Verb, ToLongFunction<TimeUnit>> timeouts = new HashMap<>();
+
+    /** Installs {@code supplier} as the serializer of {@code verb}; only the first override per verb is recorded. */
+    private void unsafeSetSerializer(Verb verb, Supplier<? extends IVersionedAsymmetricSerializer<?, ?>> supplier) throws Throwable
+    {
+        serializers.putIfAbsent(verb, verb.unsafeSetSerializer(supplier));
+    }
+
+    /** Installs {@code supplier} as the handler of {@code verb}; only the first override per verb is recorded. */
+    protected void unsafeSetHandler(Verb verb, Supplier<? extends IVerbHandler<?>> supplier) throws Throwable
+    {
+        handlers.putIfAbsent(verb, verb.unsafeSetHandler(supplier));
+    }
+
+    /** Installs {@code expiration} as the expiration of {@code verb}; only the first override per verb is recorded. */
+    private void unsafeSetExpiration(Verb verb, ToLongFunction<TimeUnit> expiration) throws Throwable
+    {
+        timeouts.putIfAbsent(verb, verb.unsafeSetExpiration(expiration));
+    }
+
+    @BeforeClass
+    public static void startup()
+    {
+        DatabaseDescriptor.daemonInitialization();
+    }
+
+    /** Re-installs every recorded serializer, handler and expiration and forgets the recordings. */
+    @After
+    public void cleanup() throws Throwable
+    {
+        for (Map.Entry<Verb, Supplier<? extends IVersionedAsymmetricSerializer<?, ?>>> e : serializers.entrySet())
+            e.getKey().unsafeSetSerializer(e.getValue());
+        serializers.clear();
+        for (Map.Entry<Verb, Supplier<? extends IVerbHandler<?>>> e : handlers.entrySet())
+            e.getKey().unsafeSetHandler(e.getValue());
+        handlers.clear();
+        for (Map.Entry<Verb, ToLongFunction<TimeUnit>> e : timeouts.entrySet())
+            e.getKey().unsafeSetExpiration(e.getValue());
+        timeouts.clear();
+    }
+
+    /** Messages held back on the inbound side for longer than their expiration must all be counted as expired. */
+    @Test
+    public void testExpireInbound() throws Throwable
+    {
+        DatabaseDescriptor.setCrossNodeTimeout(true);
+        testOneManual((settings, inbound, outbound, endpoint, handler) -> {
+            unsafeSetSerializer(Verb._TEST_1, FakePayloadSerializer::new);
+
+            // Establish the connection with normal delivery first
+            CountDownLatch connectionLatch = new CountDownLatch(1);
+            unsafeSetHandler(Verb._TEST_1, () -> v -> {
+                connectionLatch.countDown();
+            });
+            outbound.enqueue(Message.out(Verb._TEST_1, 1L));
+            Assert.assertTrue(connectionLatch.await(10, SECONDS));
+
+            // Slow things down: every read is held for 100ms while the verb expires after 50ms
+            unsafeSetExpiration(Verb._TEST_1, unit -> unit.convert(50, MILLISECONDS));
+            handler.withLatency(100, MILLISECONDS);
+
+            unsafeSetHandler(Verb._TEST_1, () -> v -> {
+                throw new RuntimeException("Should have not been triggered " + v);
+            });
+            int expireMessages = 10;
+            for (int i = 0; i < expireMessages; i++)
+                outbound.enqueue(Message.out(Verb._TEST_1, 1L));
+
+            InboundMessageHandlers inboundHandlers = MessagingService.instance().getInbound(endpoint);
+            waitForCondition(() -> inboundHandlers.expiredCount() == expireMessages);
+            Assert.assertEquals(expireMessages, inboundHandlers.expiredCount());
+        });
+    }
+
+    /** Delivery works, then everything expires under injected latency, then delivery resumes once latency drops. */
+    @Test
+    public void testExpireSome() throws Throwable
+    {
+        DatabaseDescriptor.setCrossNodeTimeout(true);
+        testOneManual((settings, inbound, outbound, endpoint, handler) -> {
+            unsafeSetSerializer(Verb._TEST_1, FakePayloadSerializer::new);
+            connect(outbound);
+
+            AtomicInteger counter = new AtomicInteger();
+            unsafeSetHandler(Verb._TEST_1, () -> v -> {
+                counter.incrementAndGet();
+            });
+
+            // Round 1: normal delivery, every message arrives
+            int expireMessages = 10;
+            for (int i = 0; i < expireMessages; i++)
+                outbound.enqueue(Message.out(Verb._TEST_1, 1L));
+            waitForCondition(() -> counter.get() == expireMessages);
+
+            // Round 2: 100ms inbound latency against a 50ms expiration, every message expires
+            unsafeSetExpiration(Verb._TEST_1, unit -> unit.convert(50, MILLISECONDS));
+            handler.withLatency(100, MILLISECONDS);
+
+            InboundMessageHandlers inboundHandlers = MessagingService.instance().getInbound(endpoint);
+            for (int i = 0; i < expireMessages; i++)
+                outbound.enqueue(Message.out(Verb._TEST_1, 1L));
+            waitForCondition(() -> inboundHandlers.expiredCount() == expireMessages);
+
+            // Round 3: latency well below the expiration, every message arrives again
+            handler.withLatency(2, MILLISECONDS);
+
+            for (int i = 0; i < expireMessages; i++)
+                outbound.enqueue(Message.out(Verb._TEST_1, 1L));
+            waitForCondition(() -> counter.get() == 2 * expireMessages);
+        });
+    }
+
+    /** Within a single batch, only the messages whose deadline has passed on arrival are expired. */
+    @Test
+    public void testExpireSomeFromBatch() throws Throwable
+    {
+        DatabaseDescriptor.setCrossNodeTimeout(true);
+        testManual((settings, inbound, outbound, endpoint, handler) -> {
+            unsafeSetSerializer(Verb._TEST_1, FakePayloadSerializer::new);
+            connect(outbound);
+
+            Message<?> msg = Message.out(Verb._TEST_1, 1L);
+            int messageSize = msg.serializedSize(MessagingService.current_version);
+            DatabaseDescriptor.setInternodeMaxMessageSizeInBytes(messageSize * 40);
+
+            AtomicInteger counter = new AtomicInteger();
+            unsafeSetHandler(Verb._TEST_1, () -> v -> {
+                counter.incrementAndGet();
+            });
+
+            unsafeSetExpiration(Verb._TEST_1, unit -> unit.convert(200, MILLISECONDS));
+            handler.withLatency(100, MILLISECONDS);
+
+            int expireMessages = 20;
+            long nanoTime = approxTime.now();
+            // Hold delivery until every message is enqueued, so they travel together
+            CountDownLatch enqueueDone = new CountDownLatch(1);
+            outbound.unsafeRunOnDelivery(() -> Uninterruptibles.awaitUninterruptibly(enqueueDone, 10, SECONDS));
+            for (int i = 0; i < expireMessages; i++)
+            {
+                boolean expire = i % 2 == 0;
+                Message.Builder builder = Message.builder(Verb._TEST_1, 1L);
+
+                if (settings.right.acceptVersions == ConnectionTest.legacy)
+                {
+                    // Backdate messages; leave them 50 milliseconds to get through the outbound path
+                    builder.withCreatedAt(nanoTime - (expire ? 0 : MILLISECONDS.toNanos(150)));
+                }
+                else
+                {
+                    // Give messages 50 milliseconds to get through the outbound path
+                    builder.withCreatedAt(nanoTime)
+                           .withExpiresAt(nanoTime + (expire ? MILLISECONDS.toNanos(50) : MILLISECONDS.toNanos(1000)));
+                }
+                outbound.enqueue(builder.build());
+            }
+            enqueueDone.countDown();
+
+            // Every other message is set up to expire, so half expire and half arrive
+            int expected = expireMessages / 2;
+            InboundMessageHandlers inboundHandlers = MessagingService.instance().getInbound(endpoint);
+            waitForCondition(() -> inboundHandlers.expiredCount() == expected && counter.get() == expected,
+                             () -> String.format("Expired: %d, Arrived: %d", inboundHandlers.expiredCount(), counter.get()));
+        });
+    }
+
+    /** After the inbound side drops the connection mid-read, the outbound side can reconnect. */
+    @Test
+    public void suddenDisconnect() throws Throwable
+    {
+        testManual((settings, inbound, outbound, endpoint, handler) -> {
+            // Go back to normal forwarding once the forced close has happened
+            handler.onDisconnect(() -> handler.reset());
+
+            unsafeSetSerializer(Verb._TEST_1, FakePayloadSerializer::new);
+            connect(outbound);
+
+            CountDownLatch closeLatch = new CountDownLatch(1);
+            handler.withCloseAfterRead(closeLatch::countDown);
+            AtomicInteger counter = new AtomicInteger();
+            unsafeSetHandler(Verb._TEST_1, () -> v -> counter.incrementAndGet());
+
+            outbound.enqueue(Message.out(Verb._TEST_1, 1L));
+            waitForCondition(() -> !outbound.isConnected());
+            // The disconnect must have been the one we injected
+            Assert.assertTrue(closeLatch.await(10, SECONDS));
+
+            connect(outbound);
+            Assert.assertTrue(outbound.isConnected());
+            Assert.assertEquals(0, counter.get());
+        });
+    }
+
+    /** Corrupted handshake bytes (bad CRC, bad protocol magic) must prevent the connection from being established. */
+    @Test
+    public void testCorruptionOnHandshake() throws Throwable
+    {
+        testManual((settings, inbound, outbound, endpoint, handler) -> {
+            unsafeSetSerializer(Verb._TEST_1, FakePayloadSerializer::new);
+            // Invalid CRC: corrupt a byte in the middle of each read
+            handler.withPayloadTransform(msg -> {
+                ByteBuf bb = (ByteBuf) msg;
+                if (bb.isReadable())
+                    flipByte(bb, bb.readerIndex() + bb.readableBytes() / 2);
+                return msg;
+            });
+            tryConnect(outbound, 1, SECONDS, false);
+            Assert.assertFalse(outbound.isConnected());
+
+            // Invalid protocol magic: corrupt the first readable byte of each read
+            handler.withPayloadTransform(msg -> {
+                ByteBuf bb = (ByteBuf) msg;
+                if (bb.isReadable())
+                    flipByte(bb, bb.readerIndex());
+                return msg;
+            });
+            tryConnect(outbound, 1, SECONDS, false);
+            Assert.assertFalse(outbound.isConnected());
+            if (settings.right.framing == CRC)
+            {
+                Assert.assertEquals(2, outbound.connectionAttempts());
+                Assert.assertEquals(0, outbound.successfulConnections());
+            }
+        });
+    }
+
+    /**
+     * Inverts every bit of the byte at absolute index {@code index}. Unlike writing a fixed value, this is
+     * guaranteed to change the buffer content.
+     */
+    private static void flipByte(ByteBuf bb, int index)
+    {
+        bb.setByte(index, ~bb.getByte(index));
+    }
+
+    /** Polls {@code cond} until it holds, failing after 10 seconds. */
+    private static void waitForCondition(Supplier<Boolean> cond) throws Throwable
+    {
+        waitForCondition(cond, () -> "Condition was not met within 10 seconds");
+    }
+
+    /**
+     * Polls {@code cond} on the calling thread until it holds. Throws an {@link AssertionError} carrying
+     * {@code s}'s message after 10 seconds. Polling here, rather than spinning on a pool thread, means
+     * nothing is left running once the wait gives up.
+     */
+    private static void waitForCondition(Supplier<Boolean> cond, Supplier<String> s) throws Throwable
+    {
+        long deadline = System.nanoTime() + SECONDS.toNanos(10);
+        while (!cond.get())
+        {
+            if (System.nanoTime() - deadline >= 0)
+                throw new AssertionError(s.get());
+            Thread.sleep(1);
+        }
+    }
+
+    /** Serializes a {@code Long} payload as {@code size} repeated copies and checks they all match on read. */
+    private static class FakePayloadSerializer implements IVersionedSerializer<Long>
+    {
+        private final int size;
+
+        private FakePayloadSerializer()
+        {
+            this(1);
+        }
+
+        // Takes long and repeats it size times
+        private FakePayloadSerializer(int size)
+        {
+            this.size = size;
+        }
+
+        public void serialize(Long i, DataOutputPlus out, int version) throws IOException
+        {
+            for (int j = 0; j < size; j++)
+            {
+                out.writeLong(i);
+            }
+        }
+
+        public Long deserialize(DataInputPlus in, int version) throws IOException
+        {
+            long l = in.readLong();
+            for (int i = 0; i < size - 1; i++)
+            {
+                if (in.readLong() != l)
+                    throw new AssertionError();
+            }
+
+            return l;
+        }
+
+        public long serializedSize(Long t, int version)
+        {
+            return Long.BYTES * size;
+        }
+    }
+
+    /** Body of a manual test: receives the connection settings, both ends, the endpoint and the proxy controller. */
+    @FunctionalInterface
+    interface ManualSendTest
+    {
+        void accept(Pair<InboundConnectionSettings, OutboundConnectionSettings> settings, InboundSockets inbound, OutboundConnection outbound, InetAddressAndPort endpoint, InboundProxyHandler.Controller handler) throws Throwable;
+    }
+
+    /** Runs {@code test} once for every entry of {@code ConnectionTest.SETTINGS}, cleaning up after each. */
+    private void testManual(ManualSendTest test) throws Throwable
+    {
+        for (ConnectionTest.Settings s: SETTINGS)
+        {
+            doTestManual(s, test);
+            cleanup();
+        }
+    }
+
+    /** Runs {@code test} against {@code SETTINGS.get(1)} only. */
+    private void testOneManual(ManualSendTest test) throws Throwable
+    {
+        testOneManual(test, 1);
+    }
+
+    /** Runs {@code test} against {@code SETTINGS.get(i)} only, then cleans up. */
+    private void testOneManual(ManualSendTest test, int i) throws Throwable
+    {
+        ConnectionTest.Settings s = SETTINGS.get(i);
+        doTestManual(s, test);
+        cleanup();
+    }
+
+    /**
+     * Opens inbound sockets whose pipelines start with an {@link InboundProxyHandler} (all sharing a single
+     * controller) and an outbound connection to them. It then runs {@code test} and tears both sides down.
+     */
+    private void doTestManual(ConnectionTest.Settings settings, ManualSendTest test) throws Throwable
+    {
+        InetAddressAndPort endpoint = FBUtilities.getBroadcastAddressAndPort();
+
+        InboundConnectionSettings inboundSettings = settings.inbound.apply(new InboundConnectionSettings())
+                                                                    .withBindAddress(endpoint)
+                                                                    .withSocketFactory(factory);
+
+        InboundSockets inbound = new InboundSockets(Collections.singletonList(inboundSettings));
+
+        OutboundConnectionSettings outboundSettings = settings.outbound.apply(new OutboundConnectionSettings(endpoint))
+                                                                       .withConnectTo(endpoint)
+                                                                       .withDefaultReserveLimits()
+                                                                       .withSocketFactory(factory);
+
+        ResourceLimits.EndpointAndGlobal reserveCapacityInBytes = new ResourceLimits.EndpointAndGlobal(new ResourceLimits.Concurrent(outboundSettings.applicationReserveSendQueueEndpointCapacityInBytes), outboundSettings.applicationReserveSendQueueGlobalCapacityInBytes);
+        OutboundConnection outbound = new OutboundConnection(settings.type, outboundSettings, reserveCapacityInBytes);
+        try
+        {
+            InboundProxyHandler.Controller controller = new InboundProxyHandler.Controller();
+            inbound.open(pipeline -> {
+                InboundProxyHandler handler = new InboundProxyHandler(controller);
+                pipeline.addLast(handler);
+            }).sync();
+            test.accept(Pair.create(inboundSettings, outboundSettings), inbound, outbound, endpoint, controller);
+        }
+        finally
+        {
+            // Start closing the outbound side before the inbound one; the second close(false) is used only to
+            // wait for completion (presumably it returns the same close future — TODO confirm)
+            outbound.close(false);
+            inbound.close().get(30L, SECONDS);
+            outbound.close(false).get(30L, SECONDS);
+            MessagingService.instance().messageHandlers.clear();
+        }
+    }
+
+    /** Sends a probe message and asserts it is handled within 10 seconds. */
+    private void connect(OutboundConnection outbound) throws Throwable
+    {
+        tryConnect(outbound, 10, SECONDS, true);
+    }
+
+    /**
+     * Replaces the {@code _TEST_1} handler with a latch, sends one message and waits up to {@code timeout} for
+     * it to be handled. Asserts delivery only when {@code throwOnFailure} is set.
+     */
+    private void tryConnect(OutboundConnection outbound, long timeout, TimeUnit timeUnit, boolean throwOnFailure) throws Throwable
+    {
+        CountDownLatch connectionLatch = new CountDownLatch(1);
+        unsafeSetHandler(Verb._TEST_1, () -> v -> {
+            connectionLatch.countDown();
+        });
+        outbound.enqueue(Message.out(Verb._TEST_1, 1L));
+        boolean delivered = connectionLatch.await(timeout, timeUnit);
+        if (throwOnFailure)
+            Assert.assertTrue(delivered);
+    }
+}
diff --git a/test/unit/org/apache/cassandra/net/proxy/InboundProxyHandler.java b/test/unit/org/apache/cassandra/net/proxy/InboundProxyHandler.java
new file mode 100644
index 0000000..7e3b004
--- /dev/null
+++ b/test/unit/org/apache/cassandra/net/proxy/InboundProxyHandler.java
@@ -0,0 +1,234 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.cassandra.net.proxy;
+
+import java.util.ArrayDeque;
+import java.util.concurrent.TimeUnit;
+import java.util.function.Function;
+
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.channel.ChannelInboundHandlerAdapter;
+import io.netty.util.ReferenceCountUtil;
+import io.netty.util.concurrent.EventExecutor;
+import io.netty.util.concurrent.ScheduledFuture;
+
+/**
+ * Test-only inbound handler that controls how, and when, each message read from the channel is passed on
+ * to the rest of the pipeline. Behaviour is set through a (possibly shared) {@link Controller}.
+ *
+ * Every read is turned into a {@link Forward} by the controller's current {@link ForwardStrategy} and
+ * queued. Forwards run strictly one at a time, in arrival order: the next one is scheduled only once the
+ * previous one has completed. Reads are requested explicitly, so the channel presumably runs with
+ * AUTO_READ disabled (TODO confirm for every pipeline this is added to).
+ */
+public class InboundProxyHandler extends ChannelInboundHandlerAdapter
+{
+    // deliveries waiting behind the currently scheduled one, in arrival order
+    private final ArrayDeque<Forward> forwardQueue;
+    // the delivery currently scheduled on the event loop, or null when none is in flight
+    private ScheduledFuture<?> scheduled = null;
+    private final Controller controller;
+
+    public InboundProxyHandler(Controller controller)
+    {
+        this.controller = controller;
+        this.forwardQueue = new ArrayDeque<>(1024);
+    }
+
+    @Override
+    public void channelActive(ChannelHandlerContext ctx) throws Exception
+    {
+        super.channelActive(ctx);
+        // request the first read; later reads are requested from channelRead
+        ctx.read();
+    }
+
+    @Override
+    public void channelInactive(ChannelHandlerContext ctx) throws Exception
+    {
+        controller.onDisconnect.run();
+
+        // Drop and release queued deliveries *before* cancelling the in-flight one. Cancelling fires its
+        // completion listener, which would otherwise pull the next queued delivery and schedule it, letting
+        // a message through after the channel went inactive.
+        Forward dropped;
+        while ((dropped = forwardQueue.poll()) != null)
+            ReferenceCountUtil.release(dropped.msg);
+
+        if (scheduled != null)
+        {
+            // the completion listener releases the message of a cancelled delivery
+            scheduled.cancel(true);
+            scheduled = null;
+        }
+
+        super.channelInactive(ctx);
+    }
+
+    @Override
+    public void channelRead(ChannelHandlerContext ctx, Object msg)
+    {
+        Forward forward = controller.forwardStrategy.forward(ctx, msg);
+        forwardQueue.offer(forward);
+        maybeScheduleNext(ctx.channel().eventLoop());
+        controller.onRead.run();
+        // keep data flowing, since reads are requested manually
+        ctx.channel().read();
+    }
+
+    /**
+     * Schedules the head of the queue unless a delivery is already in flight. Each delivery's completion
+     * listener calls back in here, so the queue drains one message at a time.
+     */
+    private void maybeScheduleNext(EventExecutor executor)
+    {
+        if (forwardQueue.isEmpty())
+        {
+            // Ran out of items to process
+            scheduled = null;
+        }
+        else if (scheduled == null)
+        {
+            Forward forward = forwardQueue.poll();
+            scheduled = forward.schedule(executor);
+            scheduled.addListener(future -> {
+                // a cancelled delivery never handed its message downstream, so it is still ours to release
+                if (future.isCancelled())
+                    ReferenceCountUtil.release(forward.msg);
+                scheduled = null;
+                maybeScheduleNext(executor);
+            });
+        }
+    }
+
+    /** One pending delivery: the message, when it was read, how long to hold it, and what to do with it. */
+    private static class Forward
+    {
+        final long arrivedAtNanos;
+        final long latencyNanos;
+        // kept so the message can be released if this delivery is dropped before it runs
+        final Object msg;
+        final Runnable handler;
+
+        private Forward(long arrivedAtNanos, long latencyNanos, Object msg, Runnable handler)
+        {
+            this.arrivedAtNanos = arrivedAtNanos;
+            this.latencyNanos = latencyNanos;
+            this.msg = msg;
+            this.handler = handler;
+        }
+
+        /**
+         * Schedules {@link #handler} to run once {@link #latencyNanos} have passed since arrival, or
+         * immediately if that moment has already passed. It uses the monotonic {@link System#nanoTime()}, so
+         * wall-clock adjustments cannot distort the injected latency.
+         */
+        ScheduledFuture<?> schedule(EventExecutor executor)
+        {
+            long elapsed = System.nanoTime() - arrivedAtNanos;
+            long runIn = Math.max(0, latencyNanos - elapsed);
+            return executor.schedule(handler, runIn, TimeUnit.NANOSECONDS);
+        }
+    }
+
+    /** Passes every message on unchanged, as soon as possible. */
+    private static class ForwardNormally implements ForwardStrategy
+    {
+        static final ForwardNormally instance = new ForwardNormally();
+
+        @Override
+        public Forward forward(ChannelHandlerContext ctx, Object msg)
+        {
+            return new Forward(System.nanoTime(),
+                               0,
+                               msg,
+                               () -> ctx.fireChannelRead(msg));
+        }
+    }
+
+    /** Decides how a message read from the channel is handed on (or not) to the rest of the pipeline. */
+    public interface ForwardStrategy
+    {
+        Forward forward(ChannelHandlerContext ctx, Object msg);
+    }
+
+    /** Passes every message on unchanged, no earlier than the given latency after it was read. */
+    private static class ForwardWithLatency implements ForwardStrategy
+    {
+        private final long latency;
+        private final TimeUnit timeUnit;
+
+        ForwardWithLatency(long latency, TimeUnit timeUnit)
+        {
+            this.latency = latency;
+            this.timeUnit = timeUnit;
+        }
+
+        @Override
+        public Forward forward(ChannelHandlerContext ctx, Object msg)
+        {
+            return new Forward(System.nanoTime(),
+                               timeUnit.toNanos(latency),
+                               msg,
+                               () -> ctx.fireChannelRead(msg));
+        }
+    }
+
+    /** Swallows the message, closes the channel, then runs the callback once the close has completed. */
+    private static class CloseAfterRead implements ForwardStrategy
+    {
+        private final Runnable afterClose;
+
+        CloseAfterRead(Runnable afterClose)
+        {
+            this.afterClose = afterClose;
+        }
+
+        @Override
+        public Forward forward(ChannelHandlerContext ctx, Object msg)
+        {
+            return new Forward(System.nanoTime(),
+                               0,
+                               msg,
+                               () -> {
+                                   // the message is swallowed rather than forwarded, so release it here
+                                   ReferenceCountUtil.release(msg);
+                                   // this runs on the channel's own event loop, so never block on the close
+                                   ctx.channel().close().addListener(future -> afterClose.run());
+                               });
+        }
+    }
+
+    /** Passes every message on after applying {@code fn} to it (e.g. to corrupt bytes in flight). */
+    private static class TransformPayload<T> implements ForwardStrategy
+    {
+        private final Function<T, T> fn;
+
+        TransformPayload(Function<T, T> fn)
+        {
+            this.fn = fn;
+        }
+
+        @Override
+        @SuppressWarnings("unchecked") // the caller of withPayloadTransform vouches for the message type
+        public Forward forward(ChannelHandlerContext ctx, Object msg)
+        {
+            return new Forward(System.nanoTime(),
+                               0,
+                               msg,
+                               () -> ctx.fireChannelRead(fn.apply((T) msg)));
+        }
+    }
+
+    /**
+     * Test-facing switchboard for one or more {@link InboundProxyHandler}s. Settings are volatile and take
+     * effect for messages read after the change.
+     */
+    public static class Controller
+    {
+        private volatile InboundProxyHandler.ForwardStrategy forwardStrategy;
+        private volatile Runnable onRead = () -> {};
+        private volatile Runnable onDisconnect = () -> {};
+
+        public Controller()
+        {
+            this.forwardStrategy = ForwardNormally.instance;
+        }
+
+        /** Callback run after every read has been queued. */
+        public void onRead(Runnable onRead)
+        {
+            this.onRead = onRead;
+        }
+
+        /** Callback run when a proxied channel becomes inactive. */
+        public void onDisconnect(Runnable onDisconnect)
+        {
+            this.onDisconnect = onDisconnect;
+        }
+
+        /** Go back to forwarding messages unchanged and without delay. */
+        public void reset()
+        {
+            this.forwardStrategy = ForwardNormally.instance;
+        }
+
+        /** Hold each subsequently read message for {@code latency} before forwarding it. */
+        public void withLatency(long latency, TimeUnit timeUnit)
+        {
+            this.forwardStrategy = new ForwardWithLatency(latency, timeUnit);
+        }
+
+        /** Close the channel on the next read instead of forwarding it, then run {@code afterClose}. */
+        public void withCloseAfterRead(Runnable afterClose)
+        {
+            this.forwardStrategy = new CloseAfterRead(afterClose);
+        }
+
+        /** Forward each subsequently read message through {@code fn}. */
+        public <T> void withPayloadTransform(Function<T, T> fn)
+        {
+            this.forwardStrategy = new TransformPayload<>(fn);
+        }
+    }
+
+}
diff --git a/test/unit/org/apache/cassandra/net/proxy/ProxyHandlerTest.java b/test/unit/org/apache/cassandra/net/proxy/ProxyHandlerTest.java
new file mode 100644
index 0000000..d070f56
--- /dev/null
+++ b/test/unit/org/apache/cassandra/net/proxy/ProxyHandlerTest.java
@@ -0,0 +1,222 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.cassandra.net.proxy;
+
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.concurrent.atomic.AtomicLong;
+import java.util.function.Consumer;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+import io.netty.bootstrap.Bootstrap;
+import io.netty.bootstrap.ServerBootstrap;
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+import io.netty.channel.Channel;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.channel.ChannelInboundHandlerAdapter;
+import io.netty.channel.ChannelInitializer;
+import io.netty.channel.ChannelOption;
+import io.netty.channel.EventLoopGroup;
+import io.netty.channel.local.LocalAddress;
+import io.netty.channel.local.LocalChannel;
+import io.netty.channel.local.LocalServerChannel;
+import io.netty.channel.nio.NioEventLoopGroup;
+import io.netty.handler.logging.LogLevel;
+import io.netty.handler.logging.LoggingHandler;
+
+public class ProxyHandlerTest
+{
+    private final Object PAYLOAD = new Object();
+
+    @Test
+    public void testLatency() throws Throwable
+    {
+        test((proxyHandler, testHandler, channel) -> {
+            int count = 1;
+            long latency = 100;
+            CountDownLatch latch = new CountDownLatch(count);
+            long start = System.currentTimeMillis();
+            testHandler.onRead = new Consumer<Object>()
+            {
+                int last = -1;
+                public void accept(Object o)
+                {
+                    // Make sure that order is preserved
+                    Assert.assertEquals(last + 1, o);
+                    last = (int) o;
+
+                    long elapsed = System.currentTimeMillis() - start;
+                    Assert.assertTrue("Latency was:" + elapsed, elapsed > latency);
+                    latch.countDown();
+                }
+            };
+
+            proxyHandler.withLatency(latency, TimeUnit.MILLISECONDS);
+
+            for (int i = 0; i < count; i++)
+            {
+                ByteBuf bb = Unpooled.buffer(Integer.BYTES);
+                bb.writeInt(i);
+                channel.writeAndFlush(i);
+            }
+
+            Assert.assertTrue(latch.await(10, TimeUnit.SECONDS));
+        });
+    }
+
+    @Test
+    public void testNormalDelivery() throws Throwable
+    {
+        test((proxyHandler, testHandler, channelPipeline) -> {
+            int count = 10;
+            CountDownLatch latch = new CountDownLatch(count);
+            AtomicLong end = new AtomicLong();
+            testHandler.onRead = (o) -> {
+                end.set(System.currentTimeMillis());
+                latch.countDown();
+            };
+
+            for (int i = 0; i < count; i++)
+                channelPipeline.writeAndFlush(PAYLOAD);
+            Assert.assertTrue(latch.await(10, TimeUnit.SECONDS));
+
+        });
+    }
+
+    @Test
+    public void testLatencyForMany() throws Throwable
+    {
+        class Event {
+            private final long latency;
+            private final long start;
+            private final int idx;
+
+            Event(long latency, int idx)
+            {
+                this.latency = latency;
+                this.start = System.currentTimeMillis();
+                this.idx = idx;
+            }
+        }
+
+        test((proxyHandler, testHandler, channel) -> {
+            int count = 150;
+            CountDownLatch latch = new CountDownLatch(count);
+            AtomicInteger counter = new AtomicInteger();
+            testHandler.onRead = new Consumer<Object>()
+            {
+                int lastSeen = -1;
+                public void accept(Object o)
+                {
+                    Event e = (Event) o;
+                    Assert.assertEquals(lastSeen + 1, e.idx);
+                    lastSeen = e.idx;
+                    long elapsed = System.currentTimeMillis() - e.start;
+                    Assert.assertTrue(elapsed >= e.latency);
+                    counter.incrementAndGet();
+                    latch.countDown();
+                }
+            };
+
+            int idx = 0;
+            for (int i = 0; i < count / 3; i++)
+            {
+                for (long latency : new long[]{ 100, 200, 0 })
+                {
+                    proxyHandler.withLatency(latency, TimeUnit.MILLISECONDS);
+                    CountDownLatch read = new CountDownLatch(1);
+                    proxyHandler.onRead(read::countDown);
+                    channel.writeAndFlush(new Event(latency, idx++));
+                    Assert.assertTrue(read.await(10, TimeUnit.SECONDS));
+                }
+            }
+
+            Assert.assertTrue(latch.await(10, TimeUnit.SECONDS));
+            Assert.assertEquals(counter.get(), count);
+        });
+    }
+
+    private interface DoTest
+    {
+        public void doTest(InboundProxyHandler.Controller proxy, TestHandler testHandler, Channel channel) throws Throwable;
+    }
+
+
+    public void test(DoTest test) throws Throwable
+    {
+        EventLoopGroup serverGroup = new NioEventLoopGroup(1);
+        EventLoopGroup clientGroup = new NioEventLoopGroup(1);
+
+        InboundProxyHandler.Controller controller = new InboundProxyHandler.Controller();
+        InboundProxyHandler proxyHandler = new InboundProxyHandler(controller);
+        TestHandler testHandler = new TestHandler();
+
+        ServerBootstrap sb = new ServerBootstrap();
+        sb.group(serverGroup)
+          .channel(LocalServerChannel.class)
+          .childHandler(new ChannelInitializer<LocalChannel>() {
+              @Override
+              public void initChannel(LocalChannel ch)
+              {
+                  ch.pipeline()
+                    .addLast(proxyHandler)
+                    .addLast(testHandler);
+              }
+          })
+          .childOption(ChannelOption.AUTO_READ, false);
+
+        Bootstrap cb = new Bootstrap();
+        cb.group(clientGroup)
+          .channel(LocalChannel.class)
+          .handler(new ChannelInitializer<LocalChannel>() {
+              @Override
+              public void initChannel(LocalChannel ch) throws Exception {
+                  ch.pipeline()
+                    .addLast(new LoggingHandler(LogLevel.TRACE));
+              }
+          });
+
+        final LocalAddress addr = new LocalAddress("test");
+
+        Channel serverChannel = sb.bind(addr).sync().channel();
+
+        Channel clientChannel = cb.connect(addr).sync().channel();
+        test.doTest(controller, testHandler, clientChannel);
+
+        clientChannel.close();
+        serverChannel.close();
+        serverGroup.shutdownGracefully();
+        clientGroup.shutdownGracefully();
+    }
+
+
+    public static class TestHandler extends ChannelInboundHandlerAdapter
+    {
+        private Consumer<Object> onRead = (o) -> {};
+        @Override
+        public void channelRead(ChannelHandlerContext ctx, Object msg)
+        {
+            onRead.accept(msg);
+        }
+    }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@cassandra.apache.org
For additional commands, e-mail: commits-help@cassandra.apache.org