You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@cassandra.apache.org by if...@apache.org on 2020/01/31 13:44:51 UTC

[cassandra] branch cassandra-2.2 updated: Add message interceptors to in-jvm dtests

This is an automated email from the ASF dual-hosted git repository.

ifesdjeen pushed a commit to branch cassandra-2.2
in repository https://gitbox.apache.org/repos/asf/cassandra.git


The following commit(s) were added to refs/heads/cassandra-2.2 by this push:
     new b2f2c70  Add message interceptors to in-jvm dtests
b2f2c70 is described below

commit b2f2c70e26a32253b0e58ee197c2f8abf01dd449
Author: Alex Petrov <ol...@gmail.com>
AuthorDate: Wed Jan 15 17:18:46 2020 +0100

    Add message interceptors to in-jvm dtests
    
    Patch by Alex Petrov; reviewed by Yifan Cai and David Capwell for CASSANDRA-15505.
---
 .../apache/cassandra/distributed/api/IMessage.java |   8 +-
 .../cassandra/distributed/api/IMessageFilters.java |  28 ++-
 .../distributed/impl/AbstractCluster.java          |  17 +-
 .../distributed/impl/IInvokableInstance.java       |   1 -
 .../cassandra/distributed/impl/Instance.java       | 187 ++++++++++--------
 .../cassandra/distributed/impl/MessageFilters.java |  76 ++++----
 .../distributed/test/MessageFiltersTest.java       | 210 +++++++++++++++++++++
 7 files changed, 391 insertions(+), 136 deletions(-)

diff --git a/test/distributed/org/apache/cassandra/distributed/api/IMessage.java b/test/distributed/org/apache/cassandra/distributed/api/IMessage.java
index 1e537ed..cd98543 100644
--- a/test/distributed/org/apache/cassandra/distributed/api/IMessage.java
+++ b/test/distributed/org/apache/cassandra/distributed/api/IMessage.java
@@ -18,12 +18,16 @@
 
 package org.apache.cassandra.distributed.api;
 
+import java.io.Serializable;
+
 import org.apache.cassandra.locator.InetAddressAndPort;
 
 /**
- * A cross-version interface for delivering internode messages via message sinks
+ * A cross-version interface for delivering internode messages via message sinks.
+ *
+ * Message implementations should be serializable so we could load into instances.
  */
-public interface IMessage
+public interface IMessage extends Serializable
 {
     int verb();
     byte[] bytes();
diff --git a/test/distributed/org/apache/cassandra/distributed/api/IMessageFilters.java b/test/distributed/org/apache/cassandra/distributed/api/IMessageFilters.java
index f7c8094..01fe972 100644
--- a/test/distributed/org/apache/cassandra/distributed/api/IMessageFilters.java
+++ b/test/distributed/org/apache/cassandra/distributed/api/IMessageFilters.java
@@ -18,29 +18,39 @@
 
 package org.apache.cassandra.distributed.api;
 
-import org.apache.cassandra.net.MessagingService;
-
 public interface IMessageFilters
 {
     public interface Filter
     {
-        Filter restore();
-        Filter drop();
+        Filter off(); // deactivate: messages matched by this filter are delivered again
+        Filter on(); // activate: messages matched by this filter are dropped
     }
 
     public interface Builder
     {
         Builder from(int ... nums);
         Builder to(int ... nums);
-        Filter ready();
+
+        /**
+         * Every message for which {@code filter} returns {@code true} will be <em>dropped</em>, provided
+         * it also satisfies this builder's verb, {@code from} and {@code to} constraints.
+         */
+        Builder messagesMatching(Matcher filter);
         Filter drop();
     }
 
-    Builder verbs(MessagingService.Verb... verbs);
+    public interface Matcher
+    {
+        boolean matches(int from, int to, IMessage message); // from/to are instance numbers
+    }
+
+    Builder verbs(int... verbs); // verb ids, compared against IMessage.verb()
     Builder allVerbs();
     void reset();
 
-    // internal
-    boolean permit(IInstance from, IInstance to, int verb);
-
+    /**
+     * Returns {@code true} when the message was not matched by any active filter and
+     * should therefore be delivered; {@code false} means it must be dropped.
+     */
+    boolean permit(int from, int to, IMessage msg);
 }
diff --git a/test/distributed/org/apache/cassandra/distributed/impl/AbstractCluster.java b/test/distributed/org/apache/cassandra/distributed/impl/AbstractCluster.java
index 45e9919..1ee0c14 100644
--- a/test/distributed/org/apache/cassandra/distributed/impl/AbstractCluster.java
+++ b/test/distributed/org/apache/cassandra/distributed/impl/AbstractCluster.java
@@ -39,7 +39,6 @@ import java.util.stream.IntStream;
 import java.util.stream.Stream;
 
 import com.google.common.collect.Sets;
-
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -48,6 +47,7 @@ import org.apache.cassandra.db.ColumnFamilyStore;
 import org.apache.cassandra.db.ConsistencyLevel;
 import org.apache.cassandra.db.Keyspace;
 import org.apache.cassandra.distributed.api.Feature;
+import org.apache.cassandra.distributed.api.ICluster;
 import org.apache.cassandra.distributed.api.ICoordinator;
 import org.apache.cassandra.distributed.api.IInstance;
 import org.apache.cassandra.distributed.api.IInstanceConfig;
@@ -55,7 +55,6 @@ import org.apache.cassandra.distributed.api.IIsolatedExecutor;
 import org.apache.cassandra.distributed.api.IListen;
 import org.apache.cassandra.distributed.api.IMessage;
 import org.apache.cassandra.distributed.api.IMessageFilters;
-import org.apache.cassandra.distributed.api.ICluster;
 import org.apache.cassandra.io.util.FileUtils;
 import org.apache.cassandra.locator.InetAddressAndPort;
 import org.apache.cassandra.net.MessagingService;
@@ -282,8 +281,18 @@ public abstract class AbstractCluster<I extends IInstance> implements ICluster,
                                   timeout, unit);
     }
 
-    public IMessageFilters filters() { return filters; }
-    public MessageFilters.Builder verbs(MessagingService.Verb ... verbs) { return filters.verbs(verbs); }
+    public IMessageFilters filters()
+    {
+        return filters;
+    }
+
+    public MessageFilters.Builder verbs(MessagingService.Verb... verbs)
+    {
+        int[] ids = new int[verbs.length]; // filters compare verbs by ordinal, the same id IMessage.verb() carries
+        for (int i = 0; i < verbs.length; ++i)
+            ids[i] = verbs[i].ordinal();
+        return filters.verbs(ids);
+    }
 
     public void disableAutoCompaction(String keyspace)
     {
diff --git a/test/distributed/org/apache/cassandra/distributed/impl/IInvokableInstance.java b/test/distributed/org/apache/cassandra/distributed/impl/IInvokableInstance.java
index 6fe5891..d0f8a5f 100644
--- a/test/distributed/org/apache/cassandra/distributed/impl/IInvokableInstance.java
+++ b/test/distributed/org/apache/cassandra/distributed/impl/IInvokableInstance.java
@@ -26,7 +26,6 @@ import java.util.function.Consumer;
 import java.util.function.Function;
 
 import org.apache.cassandra.distributed.api.IInstance;
-import org.apache.cassandra.distributed.api.IIsolatedExecutor;
 
 /**
  * This version is only supported for a Cluster running the same code as the test environment, and permits
diff --git a/test/distributed/org/apache/cassandra/distributed/impl/Instance.java b/test/distributed/org/apache/cassandra/distributed/impl/Instance.java
index 46f2d23..0647198 100644
--- a/test/distributed/org/apache/cassandra/distributed/impl/Instance.java
+++ b/test/distributed/org/apache/cassandra/distributed/impl/Instance.java
@@ -86,9 +86,11 @@ import org.apache.cassandra.tracing.Tracing;
 import org.apache.cassandra.transport.messages.ResultMessage;
 import org.apache.cassandra.utils.FBUtilities;
 import org.apache.cassandra.utils.NanoTimeToCurrentTimeMillis;
+import org.apache.cassandra.utils.Pair;
 import org.apache.cassandra.utils.Throwables;
 import org.apache.cassandra.utils.UUIDGen;
 import org.apache.cassandra.utils.concurrent.Ref;
+import org.w3c.dom.events.UIEvent;
 
 import static java.util.concurrent.TimeUnit.MINUTES;
 import static org.apache.cassandra.distributed.api.Feature.GOSSIP;
@@ -190,7 +192,9 @@ public class Instance extends IsolatedExecutor implements IInvokableInstance
     {
         BiConsumer<InetAddressAndPort, IMessage> deliverToInstance = (to, message) -> cluster.get(to).receiveMessage(message);
         BiConsumer<InetAddressAndPort, IMessage> deliverToInstanceIfNotFiltered = (to, message) -> {
-            if (cluster.filters().permit(this, cluster.get(to), message.verb()))
+            IInstance toInstance = cluster.get(to);
+            // no instance is registered at this address: there is nothing to deliver to, so drop the message
+            if (toInstance != null && cluster.filters().permit(config().num(), toInstance.config().num(), message))
                 deliverToInstance.accept(to, message);
         };
 
@@ -218,7 +222,9 @@ public class Instance extends IsolatedExecutor implements IInvokableInstance
             {
                 // Port is not passed in, so take a best guess at the destination port from this instance
                 IInstance to = cluster.get(InetAddressAndPort.getByAddressOverrideDefaults(toAddress, instance.config().broadcastAddressAndPort().port));
-                return cluster.filters().permit(instance, to, message.verb.ordinal());
+                if (to == null)
+                    return false; // the guessed address matched no instance: cannot deliver
+                return cluster.filters().permit(config().num(), to.config().num(), serializeMessage(message, id, broadcastAddressAndPort(), to.broadcastAddressAndPort()));
             }
 
             public boolean allowIncomingMessage(MessageIn message, int id)
@@ -228,6 +234,25 @@ public class Instance extends IsolatedExecutor implements IInvokableInstance
         });
     }
 
+    public static IMessage serializeMessage(MessageOut messageOut, int id, InetAddressAndPort from, InetAddressAndPort to)
+    { // writes the framing deserializeMessage() expects for VERSION_20+ peers: magic, int id, truncated send time, body
+        try (DataOutputBuffer out = new DataOutputBuffer(1024))
+        {
+            int version = MessagingService.instance().getVersion(to.address); // encode with the version known for the destination
+
+            out.writeInt(MessagingService.PROTOCOL_MAGIC);
+            out.writeInt(id);
+            long timestamp = System.currentTimeMillis();
+            out.writeInt((int) timestamp); // only the low 32 bits of the send time are written
+            messageOut.serialize(out, version);
+            return new Message(messageOut.verb.ordinal(), out.toByteArray(), id, version, from); // verb travels as its ordinal
+        }
+        catch (IOException e)
+        {
+            throw new RuntimeException(e);
+        }
+    }
+
     private class MessageDeliverySink implements IMessageSink
     {
         private final BiConsumer<InetAddressAndPort, IMessage> deliver;
@@ -240,46 +265,35 @@ public class Instance extends IsolatedExecutor implements IInvokableInstance
 
         public boolean allowOutgoingMessage(MessageOut messageOut, int id, InetAddress to)
         {
-            try (DataOutputBuffer out = new DataOutputBuffer(1024))
+            InetAddressAndPort from = broadcastAddressAndPort();
+            InetAddressAndPort toFull = lookupAddressAndPort.apply(to);
+            assert from.equals(lookupAddressAndPort.apply(messageOut.from));
+
+            IMessage serialized = serializeMessage(messageOut, id, from, toFull); // destination, not sender: selects the peer's messaging version
+
+            // Tracing logic - similar to org.apache.cassandra.net.OutboundTcpConnection.writeConnected
+            byte[] sessionBytes = (byte[]) messageOut.parameters.get(Tracing.TRACE_HEADER);
+            if (sessionBytes != null)
             {
-                InetAddressAndPort from = broadcastAddressAndPort();
-                assert from.equals(lookupAddressAndPort.apply(messageOut.from));
-                InetAddressAndPort toFull = lookupAddressAndPort.apply(to);
-                int version = MessagingService.instance().getVersion(to);
-
-                // Tracing logic - similar to org.apache.cassandra.net.OutboundTcpConnection.writeConnected
-                byte[] sessionBytes = (byte[]) messageOut.parameters.get(Tracing.TRACE_HEADER);
-                if (sessionBytes != null)
+                UUID sessionId = UUIDGen.getUUID(ByteBuffer.wrap(sessionBytes));
+                TraceState state = Tracing.instance.get(sessionId);
+                String message = String.format("Sending %s message to %s", messageOut.verb, to);
+                // session may have already finished; see CASSANDRA-5668
+                if (state == null)
                 {
-                    UUID sessionId = UUIDGen.getUUID(ByteBuffer.wrap(sessionBytes));
-                    TraceState state = Tracing.instance.get(sessionId);
-                    String message = String.format("Sending %s message to %s", messageOut.verb, toFull.address);
-                    // session may have already finished; see CASSANDRA-5668
-                    if (state == null)
-                    {
-                        byte[] traceTypeBytes = (byte[]) messageOut.parameters.get(Tracing.TRACE_TYPE);
-                        Tracing.TraceType traceType = traceTypeBytes == null ? Tracing.TraceType.QUERY : Tracing.TraceType.deserialize(traceTypeBytes[0]);
-                        TraceState.mutateWithTracing(ByteBuffer.wrap(sessionBytes), message, -1, traceType.getTTL());
-                    }
-                    else
-                    {
-                        state.trace(message);
-                        if (messageOut.verb == MessagingService.Verb.REQUEST_RESPONSE)
-                            Tracing.instance.doneWithNonLocalSession(state);
-                    }
+                    byte[] traceTypeBytes = (byte[]) messageOut.parameters.get(Tracing.TRACE_TYPE);
+                    Tracing.TraceType traceType = traceTypeBytes == null ? Tracing.TraceType.QUERY : Tracing.TraceType.deserialize(traceTypeBytes[0]);
+                    TraceState.mutateWithTracing(ByteBuffer.wrap(sessionBytes), message, -1, traceType.getTTL());
+                }
+                else
+                {
+                    state.trace(message);
+                    if (messageOut.verb == MessagingService.Verb.REQUEST_RESPONSE)
+                        Tracing.instance.doneWithNonLocalSession(state);
                 }
-
-                out.writeInt(MessagingService.PROTOCOL_MAGIC);
-                out.writeInt(id);
-                long timestamp = System.currentTimeMillis();
-                out.writeInt((int) timestamp);
-                messageOut.serialize(out, version);
-                deliver.accept(toFull, new Message(messageOut.verb.ordinal(), out.toByteArray(), id, version, from));
-            }
-            catch (IOException e)
-            {
-                throw new RuntimeException(e);
             }
+
+            deliver.accept(toFull, serialized);
             return false;
         }
 
@@ -290,57 +304,74 @@ public class Instance extends IsolatedExecutor implements IInvokableInstance
         }
     }
 
-    public void receiveMessage(IMessage imessage)
+    public static Pair<MessageIn<Object>, Integer> deserializeMessage(IMessage msg)
     {
-        sync(() -> {
-            // Based on org.apache.cassandra.net.IncomingTcpConnection.receiveMessage
-            try (DataInputStream input = new DataInputStream(new ByteArrayInputStream(imessage.bytes())))
+        // Based on org.apache.cassandra.net.IncomingTcpConnection.receiveMessage
+        try (DataInputStream input = new DataInputStream(new ByteArrayInputStream(msg.bytes())))
+        {
+            int version = msg.version();
+            if (version > MessagingService.current_version)
             {
-                int version = imessage.version();
-                if (version > MessagingService.current_version)
-                {
-                    throw new IllegalStateException(String.format("Node%d received message version %d but current version is %d",
-                                                                  this.config.num(),
-                                                                  version,
-                                                                  MessagingService.current_version));
-                }
+                throw new IllegalStateException(String.format("Received message version %d but current version is %d",
+                                                              version,
+                                                              MessagingService.current_version));
+            }
 
-                MessagingService.validateMagic(input.readInt());
-                int id;
-                if (version < MessagingService.VERSION_20)
-                    id = Integer.parseInt(input.readUTF());
-                else
-                    id = input.readInt();
-                assert imessage.id() == id;
+            MessagingService.validateMagic(input.readInt());
+            int id;
+            if (version < MessagingService.VERSION_20)
+                id = Integer.parseInt(input.readUTF());
+            else
+                id = input.readInt();
+            if (msg.id() != id)
+                throw new IllegalStateException(String.format("Message id mismatch: %d != %d", msg.id(), id));
 
-                long timestamp = System.currentTimeMillis();
-                boolean isCrossNodeTimestamp = false;
+            // make sure to readInt, even if cross_node_to is not enabled
+            int partial = input.readInt();
 
-                // make sure to readInt, even if cross_node_to is not enabled
-                int partial = input.readInt();
-                if (DatabaseDescriptor.hasCrossNodeTimeout())
-                {
-                    long crossNodeTimestamp = (timestamp & 0xFFFFFFFF00000000L) | (((partial & 0xFFFFFFFFL) << 2) >> 2);
-                    isCrossNodeTimestamp = (timestamp != crossNodeTimestamp);
-                    timestamp = crossNodeTimestamp;
-                }
+            return Pair.create(MessageIn.read(input, version, id), partial); // left: decoded message (may be null); right: raw 32-bit send time
+        }
+        catch (IOException e)
+        {
+            throw new RuntimeException(e); // keep the cause so decoding failures stay diagnosable
+        }
+    }
 
-                MessageIn message = MessageIn.read(input, version, id);
-                if (message == null)
-                {
-                    // callback expired; nothing to do
-                    return;
-                }
-                if (version <= MessagingService.current_version)
-                {
-                    MessagingService.instance().receive(message, id, timestamp, isCrossNodeTimestamp);
-                }
-                // else ignore message
+    public void receiveMessage(IMessage imessage)
+    {
+        sync(() -> {
+            Pair<MessageIn<Object>, Integer> deserialized = null;
+            try
+            {
+                deserialized = deserializeMessage(imessage);
             }
             catch (Throwable t)
             {
                 throw new RuntimeException("Exception occurred on node " + broadcastAddressAndPort(), t);
             }
+
+            MessageIn<Object> message = deserialized.left;
+            int partial = deserialized.right; // low 32 bits of the sender's send time
+
+            long timestamp = System.currentTimeMillis();
+            boolean isCrossNodeTimestamp = false;
+
+            if (DatabaseDescriptor.hasCrossNodeTimeout()) // only then is the sender's time used instead of local time
+            {
+                long crossNodeTimestamp = (timestamp & 0xFFFFFFFF00000000L) | (((partial & 0xFFFFFFFFL) << 2) >> 2);
+                isCrossNodeTimestamp = (timestamp != crossNodeTimestamp);
+                timestamp = crossNodeTimestamp;
+            }
+
+            if (message == null)
+            {
+                // MessageIn.read yielded no message (callback expired); nothing to do
+                return;
+            }
+            if (message.version <= MessagingService.current_version)
+            {
+                MessagingService.instance().receive(message, imessage.id(), timestamp, isCrossNodeTimestamp);
+            }
         }).run();
     }
 
diff --git a/test/distributed/org/apache/cassandra/distributed/impl/MessageFilters.java b/test/distributed/org/apache/cassandra/distributed/impl/MessageFilters.java
index c1607f8..c92553f 100644
--- a/test/distributed/org/apache/cassandra/distributed/impl/MessageFilters.java
+++ b/test/distributed/org/apache/cassandra/distributed/impl/MessageFilters.java
@@ -19,32 +19,23 @@
 package org.apache.cassandra.distributed.impl;
 
 import java.util.Arrays;
-import java.util.Set;
-import java.util.concurrent.CopyOnWriteArraySet;
-import java.util.function.BiConsumer;
+import java.util.List;
+import java.util.concurrent.CopyOnWriteArrayList;
 
-import org.apache.cassandra.distributed.api.IInstance;
 import org.apache.cassandra.distributed.api.IMessage;
 import org.apache.cassandra.distributed.api.IMessageFilters;
-import org.apache.cassandra.distributed.api.ICluster;
-import org.apache.cassandra.locator.InetAddressAndPort;
-import org.apache.cassandra.net.MessagingService;
 
 public class MessageFilters implements IMessageFilters
 {
-    private final Set<Filter> filters = new CopyOnWriteArraySet<>();
+    private final List<Filter> filters = new CopyOnWriteArrayList<>(); // active filters: on() appends, off() removes
 
-    public boolean permit(IInstance from, IInstance to, int verb)
+    public boolean permit(int from, int to, IMessage msg)
     {
-        if (from == null || to == null)
-            return false; // cannot deliver
-        int fromNum = from.config().num();
-        int toNum = to.config().num();
-
         for (Filter filter : filters)
-            if (filter.matches(fromNum, toNum, verb))
+        {
+            if (filter.matches(from, to, msg)) // any matching active filter means: drop
                 return false;
-
+        }
         return true;
     }
 
@@ -53,8 +44,9 @@ public class MessageFilters implements IMessageFilters
         final int[] from;
         final int[] to;
         final int[] verbs;
+        final Matcher matcher; // optional extra predicate; null matches every message
 
-        Filter(int[] from, int[] to, int[] verbs)
+        Filter(int[] from, int[] to, int[] verbs, Matcher matcher)
         {
             if (from != null)
             {
@@ -74,13 +66,14 @@ public class MessageFilters implements IMessageFilters
             this.from = from;
             this.to = to;
             this.verbs = verbs;
+            this.matcher = matcher;
         }
 
         public int hashCode()
         {
             return (from == null ? 0 : Arrays.hashCode(from))
-                    + (to == null ? 0 : Arrays.hashCode(to))
-                    + (verbs == null ? 0 : Arrays.hashCode(verbs));
+                   + (to == null ? 0 : Arrays.hashCode(to))
+                   + (verbs == null ? 0 : Arrays.hashCode(verbs)); // matcher is not hashed; equal filters always share from/to/verbs
         }
 
         public boolean equals(Object that)
@@ -91,28 +84,29 @@ public class MessageFilters implements IMessageFilters
         public boolean equals(Filter that)
         {
             return Arrays.equals(from, that.from)
-                    && Arrays.equals(to, that.to)
-                    && Arrays.equals(verbs, that.verbs);
+                   && Arrays.equals(to, that.to)
+                   && Arrays.equals(verbs, that.verbs) && matcher == that.matcher; // identity: off() must not remove a sibling with another matcher
         }
 
-        public boolean matches(int from, int to, int verb)
-        {
-            return (this.from == null || Arrays.binarySearch(this.from, from) >= 0)
-                    && (this.to == null || Arrays.binarySearch(this.to, to) >= 0)
-                    && (this.verbs == null || Arrays.binarySearch(this.verbs, verb) >= 0);
-        }
-
-        public Filter restore()
+        public Filter off() // deactivate: stop dropping messages matched by this filter
         {
             filters.remove(this);
             return this;
         }
 
-        public Filter drop()
+        public Filter on() // activate: start dropping messages matched by this filter
         {
             filters.add(this);
             return this;
         }
+
+        public boolean matches(int from, int to, IMessage msg)
+        {
+            return (this.from == null || Arrays.binarySearch(this.from, from) >= 0)
+                   && (this.to == null || Arrays.binarySearch(this.to, to) >= 0)
+                   && (this.verbs == null || Arrays.binarySearch(this.verbs, msg.verb()) >= 0)
+                   && (this.matcher == null || this.matcher.matches(from, to, msg)); // user matcher runs only if the cheaper checks pass
+        }
     }
 
     public class Builder implements IMessageFilters.Builder
@@ -120,42 +114,41 @@ public class MessageFilters implements IMessageFilters
         int[] from;
         int[] to;
         int[] verbs;
+        Matcher matcher; // optional; set by messagesMatching()
 
         private Builder(int[] verbs)
         {
             this.verbs = verbs;
         }
 
-        public Builder from(int ... nums)
+        public Builder from(int... nums)
         {
             from = nums;
             return this;
         }
 
-        public Builder to(int ... nums)
+        public Builder to(int... nums)
         {
             to = nums;
             return this;
         }
 
-        public Filter ready()
+        public Builder messagesMatching(Matcher matcher)
         {
-            return new Filter(from, to, verbs);
+            this.matcher = matcher;
+            return this;
         }
 
         public Filter drop()
         {
-            return ready().drop();
+            return new Filter(from, to, verbs, matcher).on();
         }
     }
 
-    @Override
-    public Builder verbs(MessagingService.Verb... verbs)
+    @Override
+    public Builder verbs(int... verbs)
     {
-        int[] ids = new int[verbs.length];
-        for (int i = 0 ; i < verbs.length ; ++i)
-            ids[i] = verbs[i].ordinal();
-        return new Builder(ids);
+        return new Builder(verbs);
     }
 
     @Override
@@ -169,5 +162,4 @@ public class MessageFilters implements IMessageFilters
     {
         filters.clear();
     }
-
 }
diff --git a/test/distributed/org/apache/cassandra/distributed/test/MessageFiltersTest.java b/test/distributed/org/apache/cassandra/distributed/test/MessageFiltersTest.java
new file mode 100644
index 0000000..96974d8
--- /dev/null
+++ b/test/distributed/org/apache/cassandra/distributed/test/MessageFiltersTest.java
@@ -0,0 +1,210 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.cassandra.distributed.test;
+
+import java.util.Arrays;
+import java.util.HashSet;
+import java.util.Set;
+import java.util.concurrent.atomic.AtomicInteger;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+import org.apache.cassandra.db.ConsistencyLevel;
+import org.apache.cassandra.distributed.Cluster;
+import org.apache.cassandra.distributed.api.IIsolatedExecutor;
+import org.apache.cassandra.distributed.api.IMessage;
+import org.apache.cassandra.distributed.api.IMessageFilters;
+import org.apache.cassandra.distributed.impl.Instance;
+import org.apache.cassandra.distributed.impl.MessageFilters;
+import org.apache.cassandra.locator.InetAddressAndPort;
+import org.apache.cassandra.net.MessageIn;
+import org.apache.cassandra.net.MessagingService;
+
+public class MessageFiltersTest extends DistributedTestBase
+{
+    @Test
+    public void simpleFiltersTest() throws Throwable
+    {
+        int VERB1 = MessagingService.Verb.READ.ordinal();
+        int VERB2 = MessagingService.Verb.REQUEST_RESPONSE.ordinal();
+        int VERB3 = MessagingService.Verb.READ_REPAIR.ordinal();
+        int i1 = 1;
+        int i2 = 2;
+        int i3 = 3;
+        String MSG1 = "msg1";
+        String MSG2 = "msg2";
+
+        MessageFilters filters = new MessageFilters();
+        MessageFilters.Filter filter = filters.allVerbs().from(1).drop();
+
+        Assert.assertFalse(filters.permit(i1, i2, msg(VERB1, MSG1)));
+        Assert.assertFalse(filters.permit(i1, i2, msg(VERB2, MSG1)));
+        Assert.assertFalse(filters.permit(i1, i2, msg(VERB3, MSG1)));
+        Assert.assertTrue(filters.permit(i2, i1, msg(VERB1, MSG1)));
+        filter.off();
+        Assert.assertTrue(filters.permit(i1, i2, msg(VERB1, MSG1)));
+        filters.reset();
+
+        filters.verbs(VERB1).from(1).to(2).drop();
+        Assert.assertFalse(filters.permit(i1, i2, msg(VERB1, MSG1)));
+        Assert.assertTrue(filters.permit(i1, i2, msg(VERB2, MSG1)));
+        Assert.assertTrue(filters.permit(i2, i1, msg(VERB1, MSG1)));
+        Assert.assertTrue(filters.permit(i2, i3, msg(VERB2, MSG1)));
+
+        filters.reset();
+        AtomicInteger counter = new AtomicInteger();
+        filters.verbs(VERB1).from(1).to(2).messagesMatching((from, to, msg) -> {
+            counter.incrementAndGet();
+            return Arrays.equals(msg.bytes(), MSG1.getBytes());
+        }).drop();
+        Assert.assertFalse(filters.permit(i1, i2, msg(VERB1, MSG1)));
+        Assert.assertEquals(1, counter.get());
+        Assert.assertTrue(filters.permit(i1, i2, msg(VERB1, MSG2)));
+        Assert.assertEquals(2, counter.get());
+
+        // the matcher is not consulted when the from/to/verb checks already fail
+        Assert.assertTrue(filters.permit(i2, i1, msg(VERB1, MSG1)));
+        Assert.assertEquals(2, counter.get());
+        Assert.assertTrue(filters.permit(i2, i1, msg(VERB2, MSG1)));
+        Assert.assertEquals(2, counter.get());
+        filters.reset();
+
+        filters.allVerbs().from(3, 2).to(2, 1).drop();
+        Assert.assertFalse(filters.permit(i3, i1, msg(VERB1, MSG1)));
+        Assert.assertFalse(filters.permit(i3, i2, msg(VERB1, MSG1)));
+        Assert.assertFalse(filters.permit(i2, i1, msg(VERB1, MSG1)));
+        Assert.assertTrue(filters.permit(i2, i3, msg(VERB1, MSG1)));
+        Assert.assertTrue(filters.permit(i1, i2, msg(VERB1, MSG1)));
+        Assert.assertTrue(filters.permit(i1, i3, msg(VERB1, MSG1)));
+        filters.reset();
+
+        counter.set(0);
+        filters.allVerbs().from(1).to(2).messagesMatching((from, to, msg) -> {
+            counter.incrementAndGet();
+            return false;
+        }).drop();
+        Assert.assertTrue(filters.permit(i1, i2, msg(VERB1, MSG1)));
+        Assert.assertTrue(filters.permit(i1, i3, msg(VERB1, MSG1)));
+        Assert.assertTrue(filters.permit(i1, i2, msg(VERB1, MSG1)));
+        Assert.assertEquals(2, counter.get());
+    }
+
+    IMessage msg(int verb, String msg)
+    {
+        return new IMessage() // test stub: only verb() and bytes() carry data
+        {
+            public int verb() { return verb; }
+            public byte[] bytes() { return msg.getBytes(); }
+            public int id() { return 0; }
+            public int version() { return 0;  }
+            public InetAddressAndPort from() { return null; }
+        };
+    }
+
+    @Test
+    public void testFilters() throws Throwable
+    {
+        String read = "SELECT * FROM " + KEYSPACE + ".tbl";
+        String write = "INSERT INTO " + KEYSPACE + ".tbl (pk, ck, v) VALUES (1, 1, 1)";
+
+        try (Cluster cluster = Cluster.create(2))
+        {
+            cluster.schemaChange("CREATE KEYSPACE " + KEYSPACE + " WITH replication = {'class': 'SimpleStrategy', 'replication_factor': " + cluster.size() + "};");
+            cluster.schemaChange("CREATE TABLE " + KEYSPACE + ".tbl (pk int, ck int, v int, PRIMARY KEY (pk, ck))");
+
+            // Dropping everything from 1 to 2 makes reads and writes at ALL time out whichever node coordinates
+            cluster.filters().allVerbs().from(1).to(2).drop();
+            for (int i : new int[]{ 1, 2 })
+                assertTimeOut(() -> cluster.coordinator(i).execute(read, ConsistencyLevel.ALL));
+            for (int i : new int[]{ 1, 2 })
+                assertTimeOut(() -> cluster.coordinator(i).execute(write, ConsistencyLevel.ALL));
+
+            cluster.filters().reset();
+            // Reads are going to timeout only when 1 serves as a coordinator
+            cluster.verbs(MessagingService.Verb.RANGE_SLICE).from(1).to(2).drop();
+            assertTimeOut(() -> cluster.coordinator(1).execute(read, ConsistencyLevel.ALL));
+            cluster.coordinator(2).execute(read, ConsistencyLevel.ALL);
+
+            // Writes work in both directions
+            for (int i : new int[]{ 1, 2 })
+                cluster.coordinator(i).execute(write, ConsistencyLevel.ALL);
+        }
+    }
+
+    @Test
+    public void testMessageMatching() throws Throwable
+    {
+        String read = "SELECT * FROM " + KEYSPACE + ".tbl";
+        String write = "INSERT INTO " + KEYSPACE + ".tbl (pk, ck, v) VALUES (1, 1, 1)";
+
+        try (Cluster cluster = Cluster.create(2))
+        {
+            cluster.schemaChange("CREATE KEYSPACE " + KEYSPACE + " WITH replication = {'class': 'SimpleStrategy', 'replication_factor': " + cluster.size() + "};");
+            cluster.schemaChange("CREATE TABLE " + KEYSPACE + ".tbl (pk int, ck int, v int, PRIMARY KEY (pk, ck))");
+
+            AtomicInteger counter = new AtomicInteger();
+
+            Set<Integer> verbs = new HashSet<>(Arrays.asList(MessagingService.Verb.RANGE_SLICE.ordinal(),
+                                                             MessagingService.Verb.MUTATION.ordinal()));
+
+            // Nothing is dropped (the matcher always returns false); it only inspects and counts messages from 1 to 2
+            IMessageFilters.Filter filter = cluster.filters()
+                                                   .allVerbs()
+                                                   .from(1)
+                                                   .to(2)
+                                                   .messagesMatching((from, to, msg) -> {
+                                                       // Decode and verify message on instance; return the result back here
+                                                       Integer verb = cluster.get(1).callsOnInstance((IIsolatedExecutor.SerializableCallable<Integer>) () -> {
+                                                           MessageIn decoded = Instance.deserializeMessage(msg).left;
+                                                           if (decoded != null)
+                                                               return (Integer) decoded.verb.ordinal();
+                                                           return -1;
+                                                       }).call();
+                                                       if (verb >= 0) // -1 means decoding on node 1 yielded no message; 0 is a valid verb ordinal
+                                                           Assert.assertTrue(verbs.contains(verb));
+                                                       counter.incrementAndGet();
+                                                       return false;
+                                                   }).drop();
+
+            for (int i : new int[]{ 1, 2 })
+                cluster.coordinator(i).execute(read, ConsistencyLevel.ALL);
+            for (int i : new int[]{ 1, 2 })
+                cluster.coordinator(i).execute(write, ConsistencyLevel.ALL);
+
+            filter.off();
+            Assert.assertEquals(4, counter.get());
+        }
+    }
+
+    private static void assertTimeOut(Runnable r)
+    {
+        try
+        {
+            r.run();
+            Assert.fail("Should have timed out");
+        }
+        catch (Throwable t)
+        {
+            if (!t.toString().contains("TimeoutException"))
+                throw t;
+            // expected: the operation timed out
+        }
+    }
+}
\ No newline at end of file


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@cassandra.apache.org
For additional commands, e-mail: commits-help@cassandra.apache.org