You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@cassandra.apache.org by sa...@apache.org on 2021/01/22 14:02:13 UTC

[cassandra] branch trunk updated: Restore validation of message protocol version

This is an automated email from the ASF dual-hosted git repository.

samt pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/cassandra.git


The following commit(s) were added to refs/heads/trunk by this push:
     new a0441eb  Restore validation of message protocol version
a0441eb is described below

commit a0441eb66b1976865c105069e9964104720db7fb
Author: Sam Tunnicliffe <sa...@beobal.com>
AuthorDate: Fri Jan 15 16:26:01 2021 +0000

    Restore validation of message protocol version
    
    Patch by Sam Tunnicliffe; reviewed by Mick Semb Wever for CASSANDRA-16374
---
 CHANGES.txt                                        |  1 +
 .../cassandra/transport/CQLMessageHandler.java     | 23 +++++++---
 .../cassandra/transport/PipelineConfigurator.java  |  1 +
 .../apache/cassandra/transport/PreV5Handlers.java  | 23 +++++++---
 .../apache/cassandra/transport/SimpleClient.java   |  1 +
 .../transport/ProtocolNegotiationTest.java         | 50 +++++++++++++++++++++-
 6 files changed, 87 insertions(+), 12 deletions(-)

diff --git a/CHANGES.txt b/CHANGES.txt
index a946f1d..7f2410b 100644
--- a/CHANGES.txt
+++ b/CHANGES.txt
@@ -1,4 +1,5 @@
 4.0-beta5
+ * Restore validation of each message's protocol version (CASSANDRA-16374)
  * Upgrade netty and chronicle-queue dependencies to get Auditing and native library loading working on arm64 architectures (CASSANDRA-16384,CASSANDRA-16392)
  * Release StreamingTombstoneHistogramBuilder spool when switching writers (CASSANDRA-14834)
  * Correct memtable on-heap size calculations to match actual use (CASSANDRA-16318)
diff --git a/src/java/org/apache/cassandra/transport/CQLMessageHandler.java b/src/java/org/apache/cassandra/transport/CQLMessageHandler.java
index 48cdf64..34e619a 100644
--- a/src/java/org/apache/cassandra/transport/CQLMessageHandler.java
+++ b/src/java/org/apache/cassandra/transport/CQLMessageHandler.java
@@ -58,9 +58,7 @@ import static org.apache.cassandra.utils.MonotonicClock.approxTime;
  * frames on the event loop thread and pass them on to the same consumer.
  *
  * # Flow control (backpressure)
- *
  * The size of an incoming message is explicit in the {@link Envelope.Header}.
- * {@link org.apache.cassandra.net.Message.Serializer#inferMessageSize(ByteBuffer, int, int, int)}.
  *
  * By default, every connection has 1MiB of exlusive permits available before needing to access the per-endpoint
  * and global reserves. By default, those reserves are sized proportionally to the heap - 2.5% of heap per-endpoint
@@ -68,6 +66,10 @@ import static org.apache.cassandra.utils.MonotonicClock.approxTime;
  *
  * Permits are held while CQL messages are processed and released after the response has been encoded into the
  * buffers of the response frame.
+ *
+ * A connection level option (THROW_ON_OVERLOAD) allows clients to choose the backpressure strategy when a connection
+ * has exceeded the maximum number of allowed permits. The choices are to either pause reads from the incoming socket
+ * and allow TCP backpressure to do the work, or to throw an explicit exception and rely on the client to back off.
  */
 public class CQLMessageHandler<M extends Message> extends AbstractMessageHandler
 {
@@ -82,6 +84,7 @@ public class CQLMessageHandler<M extends Message> extends AbstractMessageHandler
     private final MessageConsumer<M> dispatcher;
     private final ErrorHandler errorHandler;
     private final boolean throwOnOverload;
+    private final ProtocolVersion version;
 
     long channelPayloadBytesInFlight;
 
@@ -96,6 +99,7 @@ public class CQLMessageHandler<M extends Message> extends AbstractMessageHandler
     }
 
     CQLMessageHandler(Channel channel,
+                      ProtocolVersion version,
                       FrameDecoder decoder,
                       Envelope.Decoder envelopeDecoder,
                       Message.Decoder<M> messageDecoder,
@@ -122,11 +126,12 @@ public class CQLMessageHandler<M extends Message> extends AbstractMessageHandler
         this.dispatcher         = dispatcher;
         this.errorHandler       = errorHandler;
         this.throwOnOverload    = throwOnOverload;
+        this.version            = version;
     }
 
     protected boolean processOneContainedMessage(ShareableBytes bytes, Limit endpointReserve, Limit globalReserve)
     {
-        Envelope.Header header = extractHeader(bytes);
+        Envelope.Header header = extractHeader(bytes.get());
         // null indicates a failure to extract the CQL message header.
         // This will trigger a protocol exception and closing of the connection.
         if (null == header)
@@ -182,11 +187,17 @@ public class CQLMessageHandler<M extends Message> extends AbstractMessageHandler
         ClientMessageSizeMetrics.bytesReceivedPerRequest.update(messageSize + Envelope.Header.LENGTH);
     }
 
-    private Envelope.Header extractHeader(ShareableBytes bytes)
+    private Envelope.Header extractHeader(ByteBuffer buf)
     {
         try
         {
-            return envelopeDecoder.extractHeader(bytes.get());
+            Envelope.Header header = envelopeDecoder.extractHeader(buf);
+            if (header.version != version)
+                handleError(new ProtocolException(String.format("Invalid message version. Got %s but previous " +
+                                                                "messages on this connection had version %s",
+                                                                header.version, version)));
+            // NOTE(review): the header is still returned after handleError, so callers continue to process this message
+            return header;
         }
         catch (Throwable t)
         {
@@ -319,7 +330,7 @@ public class CQLMessageHandler<M extends Message> extends AbstractMessageHandler
         ByteBuffer buf = bytes.get();
         try
         {
-            Envelope.Header header = envelopeDecoder.extractHeader(buf);
+            Envelope.Header header = extractHeader(buf);
             // max CQL message size defaults to 256mb, so should be safe to downcast
             int messageSize = Ints.checkedCast(header.bodySizeInBytes);
             receivedBytes += buf.remaining();
diff --git a/src/java/org/apache/cassandra/transport/PipelineConfigurator.java b/src/java/org/apache/cassandra/transport/PipelineConfigurator.java
index 82865f2..0250f13 100644
--- a/src/java/org/apache/cassandra/transport/PipelineConfigurator.java
+++ b/src/java/org/apache/cassandra/transport/PipelineConfigurator.java
@@ -283,6 +283,7 @@ public class PipelineConfigurator
         CQLMessageHandler.MessageConsumer<Message.Request> messageConsumer = messageConsumer();
         CQLMessageHandler<Message.Request> processor =
             new CQLMessageHandler<>(ctx.channel(),
+                                    version, // compared against the header version of every incoming envelope
                                     frameDecoder,
                                     envelopeDecoder,
                                     messageDecoder,
diff --git a/src/java/org/apache/cassandra/transport/PreV5Handlers.java b/src/java/org/apache/cassandra/transport/PreV5Handlers.java
index b10722b..d8a66ee 100644
--- a/src/java/org/apache/cassandra/transport/PreV5Handlers.java
+++ b/src/java/org/apache/cassandra/transport/PreV5Handlers.java
@@ -180,6 +180,14 @@ public class PreV5Handlers
         {
             try
             {
+                ProtocolVersion version = getConnectionVersion(ctx);
+                if (source.header.version != version) // every message must use the version already in use on this connection
+                {
+                    throw new ProtocolException(
+                        String.format("Invalid message version. Got %s but previous " +
+                                      "messages on this connection had version %s",
+                                      source.header.version, version));
+                }
                 results.add(Message.Decoder.decodeMessage(ctx.channel(), source));
             }
             catch (Throwable ex)
@@ -199,12 +207,9 @@ public class PreV5Handlers
     {
         public static final ProtocolEncoder instance = new ProtocolEncoder();
         private ProtocolEncoder(){}
-
         public void encode(ChannelHandlerContext ctx, Message source, List results)
         {
-            Connection connection = ctx.channel().attr(Connection.attributeKey).get();
-            // The only case the connection can be null is when we send the initial STARTUP message (client side thus)
-            ProtocolVersion version = connection == null ? ProtocolVersion.CURRENT : connection.getVersion();
+            ProtocolVersion version = getConnectionVersion(ctx);
             results.add(source.encode(version));
         }
     }
@@ -227,7 +232,7 @@ public class PreV5Handlers
             ErrorMessage errorMessage = ErrorMessage.fromException(cause, handler);
             if (ctx.channel().isOpen())
             {
-                ChannelFuture future = ctx.writeAndFlush(errorMessage.encode(ProtocolVersion.CURRENT));
+                ChannelFuture future = ctx.writeAndFlush(errorMessage.encode(getConnectionVersion(ctx)));
                 // On protocol exception, close the channel as soon as the message have been sent
                 if (cause instanceof ProtocolException)
                     future.addListener((ChannelFutureListener) f -> ctx.close());
@@ -235,4 +240,12 @@ public class PreV5Handlers
             JVMStabilityInspector.inspectThrowable(cause);
         }
     }
+
+    private static ProtocolVersion getConnectionVersion(ChannelHandlerContext ctx)
+    {
+        Connection connection = ctx.channel().attr(Connection.attributeKey).get();
+        // Null when no Connection is attached to the channel yet (e.g. the initial STARTUP message); fall back to CURRENT
+        return connection == null ? ProtocolVersion.CURRENT : connection.getVersion();
+    }
+
 }
diff --git a/src/java/org/apache/cassandra/transport/SimpleClient.java b/src/java/org/apache/cassandra/transport/SimpleClient.java
index 5ad4c17..13a5e17 100644
--- a/src/java/org/apache/cassandra/transport/SimpleClient.java
+++ b/src/java/org/apache/cassandra/transport/SimpleClient.java
@@ -472,6 +472,7 @@ public class SimpleClient implements Closeable
 
             CQLMessageHandler<Message.Response> processor =
                 new CQLMessageHandler<Message.Response>(ctx.channel(),
+                                        version, // compared against the header version of every incoming envelope
                                         frameDecoder,
                                         envelopeDecoder,
                                         messageDecoder,
diff --git a/test/unit/org/apache/cassandra/transport/ProtocolNegotiationTest.java b/test/unit/org/apache/cassandra/transport/ProtocolNegotiationTest.java
index f33d8e6..e16959a 100644
--- a/test/unit/org/apache/cassandra/transport/ProtocolNegotiationTest.java
+++ b/test/unit/org/apache/cassandra/transport/ProtocolNegotiationTest.java
@@ -31,6 +31,7 @@ import com.datastax.driver.core.Session;
 import org.apache.cassandra.cql3.CQLTester;
 import org.apache.cassandra.cql3.QueryOptions;
 import org.apache.cassandra.cql3.QueryProcessor;
+import org.apache.cassandra.exceptions.TransportException;
 import org.apache.cassandra.transport.messages.OptionsMessage;
 import org.apache.cassandra.transport.messages.QueryMessage;
 import org.apache.cassandra.transport.messages.StartupMessage;
@@ -97,11 +98,16 @@ public class ProtocolNegotiationTest extends CQLTester
         ProtocolVersion.SUPPORTED.forEach(this::testStreamIdsAcrossNegotiation);
     }
 
+    @Test
+    public void validateReceivedMessageVersionMatchesNegotiated()
+    {
+        ProtocolVersion.SUPPORTED.forEach(this::validateMessageVersion);
+    }
+
     private void testStreamIdsAcrossNegotiation(ProtocolVersion version)
     {
         long seed = System.currentTimeMillis();
         Random random = new Random(seed);
-        reinitializeNetwork();
         SimpleClient.Builder builder = SimpleClient.builder(nativeAddr.getHostAddress(), nativePort);
         if (version.isBeta())
             builder.useBeta();
@@ -185,4 +191,46 @@ public class ProtocolNegotiationTest extends CQLTester
         }
     }
 
+    private void validateMessageVersion(ProtocolVersion version)
+    {
+        SimpleClient.Builder builder = SimpleClient.builder(nativeAddr.getHostAddress(), nativePort)
+                                                   .protocolVersion(version);
+        if (version.isBeta())
+            builder.useBeta();
+        // Pick a random version >= MIN_SUPPORTED_VERSION other than $version. NOTE(review): nextInt(length - 1) never picks the last constant; confirm intended
+        Random r = new Random();
+        ProtocolVersion wrongVersion = version;
+        while (wrongVersion.isSmallerThan(ProtocolVersion.MIN_SUPPORTED_VERSION) || wrongVersion == version)
+            wrongVersion = ProtocolVersion.values()[r.nextInt(ProtocolVersion.values().length - 1)];
+
+        try (SimpleClient client = builder.build().connect(false))
+        {
+            // The connection has been negotiated to use $version. Force the next message to be
+            // encoded with a different version and it should trigger a ProtocolException
+            final ProtocolVersion v = wrongVersion;
+            QueryMessage query = new QueryMessage("SELECT * FROM system.local", QueryOptions.DEFAULT)
+            {
+                @Override
+                public Envelope encode(ProtocolVersion originalVersion)
+                {
+                    return super.encode(v);
+                }
+            };
+            try
+            {
+                client.execute(query);
+                fail("Expected a protocol exception");
+            }
+            catch (RuntimeException e)
+            {
+                assertTrue("Unexpected cause: " + e.getCause(), e.getCause() instanceof TransportException);
+                assertTrue(e.getCause().getMessage(), e.getCause().getMessage().startsWith("Invalid message version"));
+            }
+        }
+        catch (IOException e)
+        {
+            // Propagate with the cause attached so the test report shows why the connection failed
+            throw new AssertionError("Error establishing connection", e);
+        }
+    }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@cassandra.apache.org
For additional commands, e-mail: commits-help@cassandra.apache.org