You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@cassandra.apache.org by sl...@apache.org on 2012/11/08 14:29:46 UTC

git commit: make binary protocol more reliable to bad clients

Updated Branches:
  refs/heads/trunk c016b3184 -> e27a95587


make binary protocol more reliable to bad clients


Project: http://git-wip-us.apache.org/repos/asf/cassandra/repo
Commit: http://git-wip-us.apache.org/repos/asf/cassandra/commit/e27a9558
Tree: http://git-wip-us.apache.org/repos/asf/cassandra/tree/e27a9558
Diff: http://git-wip-us.apache.org/repos/asf/cassandra/diff/e27a9558

Branch: refs/heads/trunk
Commit: e27a9558737f4f0d18104fd57169f681bcfc1703
Parents: c016b31
Author: Sylvain Lebresne <sy...@datastax.com>
Authored: Thu Nov 8 14:29:12 2012 +0100
Committer: Sylvain Lebresne <sy...@datastax.com>
Committed: Thu Nov 8 14:29:34 2012 +0100

----------------------------------------------------------------------
 .../org/apache/cassandra/transport/CBUtil.java     |   20 +++++-
 src/java/org/apache/cassandra/transport/Event.java |   19 ++---
 src/java/org/apache/cassandra/transport/Frame.java |   40 +++++++++--
 .../org/apache/cassandra/transport/Message.java    |   15 ++++-
 .../cassandra/transport/messages/ErrorMessage.java |    7 +-
 .../transport/messages/RegisterMessage.java        |   15 ++--
 .../transport/messages/StartupMessage.java         |   57 ++++++++-------
 7 files changed, 113 insertions(+), 60 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/cassandra/blob/e27a9558/src/java/org/apache/cassandra/transport/CBUtil.java
----------------------------------------------------------------------
diff --git a/src/java/org/apache/cassandra/transport/CBUtil.java b/src/java/org/apache/cassandra/transport/CBUtil.java
index 0d3be07..2ad8c72 100644
--- a/src/java/org/apache/cassandra/transport/CBUtil.java
+++ b/src/java/org/apache/cassandra/transport/CBUtil.java
@@ -145,6 +145,24 @@ public abstract class CBUtil
         return ConsistencyLevel.fromCode(cb.readUnsignedShort());
     }
 
+    public static <T extends Enum<T>> T readEnumValue(Class<T> enumType, ChannelBuffer cb)
+    {
+        String value = CBUtil.readString(cb);
+        try
+        {
+            return Enum.valueOf(enumType, value.toUpperCase());
+        }
+        catch (IllegalArgumentException e)
+        {
+            throw new ProtocolException(String.format("Invalid value '%s' for %s", value, enumType.getSimpleName()));
+        }
+    }
+
+    public static <T extends Enum<T>> ChannelBuffer enumValueToCB(T enumValue)
+    {
+        return stringToCB(enumValue.toString());
+    }
+
     public static ChannelBuffer uuidToCB(UUID uuid)
     {
         return ChannelBuffers.wrappedBuffer(UUIDGen.decompose(uuid));
@@ -166,7 +184,7 @@ public abstract class CBUtil
     public static List<String> readStringList(ChannelBuffer cb)
     {
         int length = cb.readUnsignedShort();
-        List<String> l = new ArrayList<String>();
+        List<String> l = new ArrayList<String>(length);
         for (int i = 0; i < length; i++)
             l.add(readString(cb));
         return l;

http://git-wip-us.apache.org/repos/asf/cassandra/blob/e27a9558/src/java/org/apache/cassandra/transport/Event.java
----------------------------------------------------------------------
diff --git a/src/java/org/apache/cassandra/transport/Event.java b/src/java/org/apache/cassandra/transport/Event.java
index 33b08f4..d973a59 100644
--- a/src/java/org/apache/cassandra/transport/Event.java
+++ b/src/java/org/apache/cassandra/transport/Event.java
@@ -36,7 +36,7 @@ public abstract class Event
 
     public static Event deserialize(ChannelBuffer cb)
     {
-        switch (Enum.valueOf(Type.class, CBUtil.readString(cb).toUpperCase()))
+        switch (CBUtil.readEnumValue(Type.class, cb))
         {
             case TOPOLOGY_CHANGE:
                 return TopologyChange.deserializeEvent(cb);
@@ -50,8 +50,7 @@ public abstract class Event
 
     public ChannelBuffer serialize()
     {
-        return ChannelBuffers.wrappedBuffer(CBUtil.stringToCB(type.toString()),
-                                            serializeEvent());
+        return ChannelBuffers.wrappedBuffer(CBUtil.enumValueToCB(type), serializeEvent());
     }
 
     protected abstract ChannelBuffer serializeEvent();
@@ -88,15 +87,14 @@ public abstract class Event
         // Assumes the type has already by been deserialized
         private static TopologyChange deserializeEvent(ChannelBuffer cb)
         {
-            Change change = Enum.valueOf(Change.class, CBUtil.readString(cb).toUpperCase());
+            Change change = CBUtil.readEnumValue(Change.class, cb);
             InetSocketAddress node = CBUtil.readInet(cb);
             return new TopologyChange(change, node);
         }
 
         protected ChannelBuffer serializeEvent()
         {
-            return ChannelBuffers.wrappedBuffer(CBUtil.stringToCB(change.toString()),
-                                                CBUtil.inetToCB(node));
+            return ChannelBuffers.wrappedBuffer(CBUtil.enumValueToCB(change), CBUtil.inetToCB(node));
         }
 
         @Override
@@ -133,15 +131,14 @@ public abstract class Event
         // Assumes the type has already by been deserialized
         private static StatusChange deserializeEvent(ChannelBuffer cb)
         {
-            Status status = Enum.valueOf(Status.class, CBUtil.readString(cb).toUpperCase());
+            Status status = CBUtil.readEnumValue(Status.class, cb);
             InetSocketAddress node = CBUtil.readInet(cb);
             return new StatusChange(status, node);
         }
 
         protected ChannelBuffer serializeEvent()
         {
-            return ChannelBuffers.wrappedBuffer(CBUtil.stringToCB(status.toString()),
-                                                CBUtil.inetToCB(node));
+            return ChannelBuffers.wrappedBuffer(CBUtil.enumValueToCB(status), CBUtil.inetToCB(node));
         }
 
         @Override
@@ -175,7 +172,7 @@ public abstract class Event
         // Assumes the type has already by been deserialized
         private static SchemaChange deserializeEvent(ChannelBuffer cb)
         {
-            Change change = Enum.valueOf(Change.class, CBUtil.readString(cb).toUpperCase());
+            Change change = CBUtil.readEnumValue(Change.class, cb);
             String keyspace = CBUtil.readString(cb);
             String table = CBUtil.readString(cb);
             return new SchemaChange(change, keyspace, table);
@@ -183,7 +180,7 @@ public abstract class Event
 
         protected ChannelBuffer serializeEvent()
         {
-            return ChannelBuffers.wrappedBuffer(CBUtil.stringToCB(change.toString()),
+            return ChannelBuffers.wrappedBuffer(CBUtil.enumValueToCB(change),
                                                 CBUtil.stringToCB(keyspace),
                                                 CBUtil.stringToCB(table));
         }

http://git-wip-us.apache.org/repos/asf/cassandra/blob/e27a9558/src/java/org/apache/cassandra/transport/Frame.java
----------------------------------------------------------------------
diff --git a/src/java/org/apache/cassandra/transport/Frame.java b/src/java/org/apache/cassandra/transport/Frame.java
index 6ff072c..be9df1a 100644
--- a/src/java/org/apache/cassandra/transport/Frame.java
+++ b/src/java/org/apache/cassandra/transport/Frame.java
@@ -26,7 +26,7 @@ import org.jboss.netty.buffer.ChannelBuffers;
 import org.jboss.netty.channel.*;
 import org.jboss.netty.handler.codec.oneone.OneToOneDecoder;
 import org.jboss.netty.handler.codec.oneone.OneToOneEncoder;
-import org.jboss.netty.handler.codec.frame.LengthFieldBasedFrameDecoder;
+import org.jboss.netty.handler.codec.frame.*;
 
 import org.apache.cassandra.utils.ByteBufferUtil;
 
@@ -144,7 +144,7 @@ public class Frame
 
         public Decoder(Connection.Tracker tracker, Connection.Factory factory)
         {
-            super(MAX_FRAME_LENTH, 4, 4);
+            super(MAX_FRAME_LENTH, 4, 4, 0, 0, true);
             this.connection = factory.newConnection(tracker);
         }
 
@@ -159,12 +159,40 @@ public class Frame
         protected Object decode(ChannelHandlerContext ctx, Channel channel, ChannelBuffer buffer)
         throws Exception
         {
-            ChannelBuffer frame = (ChannelBuffer) super.decode(ctx, channel, buffer);
-            if (frame == null)
+            try
             {
-                return null;
+                // We must at least validate that the frame version is something we support/know and it doesn't hurt to
+                // check the opcode is not garbage. And we should do that independently of what the bytes corresponding
+                // to the frame length are, i.e. we shouldn't wait for super.decode() to return non-null.
+                if (buffer.readableBytes() == 0)
+                    return null;
+
+                int firstByte = buffer.getByte(0);
+                Message.Direction direction = Message.Direction.extractFromVersion(firstByte);
+                int version = firstByte & 0x7F;
+                // We really only support the current version so far
+                if (version != Header.CURRENT_VERSION)
+                    throw new ProtocolException("Invalid or unsupported protocol version: " + version);
+
+                // Validate the opcode
+                if (buffer.readableBytes() >= 4)
+                    Message.Type.fromOpcode(buffer.getByte(3), direction);
+
+                ChannelBuffer frame = (ChannelBuffer) super.decode(ctx, channel, buffer);
+                if (frame == null)
+                {
+                    return null;
+                }
+                return Frame.create(frame, connection);
+            }
+            catch (CorruptedFrameException e)
+            {
+                throw new ProtocolException(e.getMessage());
+            }
+            catch (TooLongFrameException e)
+            {
+                throw new ProtocolException(e.getMessage());
             }
-            return Frame.create(frame, connection);
         }
     }
 

http://git-wip-us.apache.org/repos/asf/cassandra/blob/e27a9558/src/java/org/apache/cassandra/transport/Message.java
----------------------------------------------------------------------
diff --git a/src/java/org/apache/cassandra/transport/Message.java b/src/java/org/apache/cassandra/transport/Message.java
index a60a27a..d4d3da6 100644
--- a/src/java/org/apache/cassandra/transport/Message.java
+++ b/src/java/org/apache/cassandra/transport/Message.java
@@ -293,11 +293,22 @@ public abstract class Message
         }
 
         @Override
-        public void exceptionCaught(ChannelHandlerContext ctx, ExceptionEvent e)
+        public void exceptionCaught(final ChannelHandlerContext ctx, ExceptionEvent e)
         throws Exception
         {
             if (ctx.getChannel().isOpen())
-                ctx.getChannel().write(ErrorMessage.fromException(e.getCause()));
+            {
+                ChannelFuture future = ctx.getChannel().write(ErrorMessage.fromException(e.getCause()));
+                // On protocol exception, close the channel as soon as the message has been sent
+                if (e.getCause() instanceof ProtocolException)
+                {
+                    future.addListener(new ChannelFutureListener() {
+                        public void operationComplete(ChannelFuture future) {
+                            ctx.getChannel().close();
+                        }
+                    });
+                }
+            }
         }
     }
 }

http://git-wip-us.apache.org/repos/asf/cassandra/blob/e27a9558/src/java/org/apache/cassandra/transport/messages/ErrorMessage.java
----------------------------------------------------------------------
diff --git a/src/java/org/apache/cassandra/transport/messages/ErrorMessage.java b/src/java/org/apache/cassandra/transport/messages/ErrorMessage.java
index 114c2fb..c8ff879 100644
--- a/src/java/org/apache/cassandra/transport/messages/ErrorMessage.java
+++ b/src/java/org/apache/cassandra/transport/messages/ErrorMessage.java
@@ -76,7 +76,7 @@ public class ErrorMessage extends Message.Response
                     break;
                 case WRITE_TIMEOUT:
                 case READ_TIMEOUT:
-                    ConsistencyLevel cl = Enum.valueOf(ConsistencyLevel.class, CBUtil.readString(body));
+                    ConsistencyLevel cl = CBUtil.readConsistencyLevel(body);
                     int received = body.readInt();
                     int blockFor = body.readInt();
                     if (code == ExceptionCode.WRITE_TIMEOUT)
@@ -141,15 +141,14 @@ public class ErrorMessage extends Message.Response
                     RequestTimeoutException rte = (RequestTimeoutException)msg.error;
                     boolean isWrite = msg.error.code() == ExceptionCode.WRITE_TIMEOUT;
 
-                    ByteBuffer rteCl = ByteBufferUtil.bytes(rte.consistency.toString());
+                    ChannelBuffer rteCl = CBUtil.consistencyLevelToCB(rte.consistency);
                     ByteBuffer writeType = isWrite
                                          ? ByteBufferUtil.bytes(((WriteTimeoutException)rte).writeType.toString())
                                          : null;
 
                     int extraSize = isWrite  ? 2 + writeType.remaining() : 1;
-                    acb = ChannelBuffers.buffer(2 + rteCl.remaining() + 8 + extraSize);
+                    acb = ChannelBuffers.buffer(rteCl.writableBytes() + 8 + extraSize);
 
-                    acb.writeShort((short)rteCl.remaining());
                     acb.writeBytes(rteCl);
                     acb.writeInt(rte.received);
                     acb.writeInt(rte.blockFor);

http://git-wip-us.apache.org/repos/asf/cassandra/blob/e27a9558/src/java/org/apache/cassandra/transport/messages/RegisterMessage.java
----------------------------------------------------------------------
diff --git a/src/java/org/apache/cassandra/transport/messages/RegisterMessage.java b/src/java/org/apache/cassandra/transport/messages/RegisterMessage.java
index 9e46e92..61e11af 100644
--- a/src/java/org/apache/cassandra/transport/messages/RegisterMessage.java
+++ b/src/java/org/apache/cassandra/transport/messages/RegisterMessage.java
@@ -32,20 +32,19 @@ public class RegisterMessage extends Message.Request
     {
         public RegisterMessage decode(ChannelBuffer body)
         {
-            List<String> l = CBUtil.readStringList(body);
-            List<Event.Type> eventTypes = new ArrayList<Event.Type>(l.size());
-            for (String s : l)
-                eventTypes.add(Enum.valueOf(Event.Type.class, s.toUpperCase()));
+            int length = body.readUnsignedShort();
+            List<Event.Type> eventTypes = new ArrayList<Event.Type>(length);
+            for (int i = 0; i < length; ++i)
+                eventTypes.add(CBUtil.readEnumValue(Event.Type.class, body));
             return new RegisterMessage(eventTypes);
         }
 
         public ChannelBuffer encode(RegisterMessage msg)
         {
-            List<String> l = new ArrayList<String>(msg.eventTypes.size());
-            for (Event.Type type : msg.eventTypes)
-                l.add(type.toString());
             ChannelBuffer cb = ChannelBuffers.dynamicBuffer();
-            CBUtil.writeStringList(cb, l);
+            cb.writeShort(msg.eventTypes.size());
+            for (Event.Type type : msg.eventTypes)
+                cb.writeBytes(CBUtil.enumValueToCB(type));
             return cb;
         }
     };

http://git-wip-us.apache.org/repos/asf/cassandra/blob/e27a9558/src/java/org/apache/cassandra/transport/messages/StartupMessage.java
----------------------------------------------------------------------
diff --git a/src/java/org/apache/cassandra/transport/messages/StartupMessage.java b/src/java/org/apache/cassandra/transport/messages/StartupMessage.java
index e781517..7ef1504 100644
--- a/src/java/org/apache/cassandra/transport/messages/StartupMessage.java
+++ b/src/java/org/apache/cassandra/transport/messages/StartupMessage.java
@@ -68,41 +68,42 @@ public class StartupMessage extends Message.Request
 
     public Message.Response execute(QueryState state)
     {
-        try
-        {
-            ClientState cState = state.getClientState();
-            String cqlVersion = options.get(CQL_VERSION);
-            if (cqlVersion == null)
-                throw new ProtocolException("Missing value CQL_VERSION in STARTUP message");
+        ClientState cState = state.getClientState();
+        String cqlVersion = options.get(CQL_VERSION);
+        if (cqlVersion == null)
+            throw new ProtocolException("Missing value CQL_VERSION in STARTUP message");
 
+        try 
+        {
             cState.setCQLVersion(cqlVersion);
-            if (cState.getCQLVersion().compareTo(new SemanticVersion("2.99.0")) < 0)
-                throw new ProtocolException(String.format("CQL version %s is not support by the binary protocol (supported version are >= 3.0.0)", cqlVersion));
+        }
+        catch (InvalidRequestException e)
+        {
+            throw new ProtocolException(e.getMessage());
+        }
 
-            if (options.containsKey(COMPRESSION))
+        if (cState.getCQLVersion().compareTo(new SemanticVersion("2.99.0")) < 0)
+            throw new ProtocolException(String.format("CQL version %s is not supported by the binary protocol (supported version are >= 3.0.0)", cqlVersion));
+
+        if (options.containsKey(COMPRESSION))
+        {
+            String compression = options.get(COMPRESSION).toLowerCase();
+            if (compression.equals("snappy"))
             {
-                String compression = options.get(COMPRESSION).toLowerCase();
-                if (compression.equals("snappy"))
-                {
-                    if (FrameCompressor.SnappyCompressor.instance == null)
-                        throw new ProtocolException("This instance does not support Snappy compression");
-                    connection.setCompressor(FrameCompressor.SnappyCompressor.instance);
-                }
-                else
-                {
-                    throw new ProtocolException(String.format("Unknown compression algorithm: %s", compression));
-                }
+                if (FrameCompressor.SnappyCompressor.instance == null)
+                    throw new ProtocolException("This instance does not support Snappy compression");
+                connection.setCompressor(FrameCompressor.SnappyCompressor.instance);
             }
-
-            if (cState.isLogged())
-                return new ReadyMessage();
             else
-                return new AuthenticateMessage(DatabaseDescriptor.getAuthenticator().getClass().getName());
-        }
-        catch (InvalidRequestException e)
-        {
-            return ErrorMessage.fromException(new ProtocolException(e.getMessage()));
+            {
+                throw new ProtocolException(String.format("Unknown compression algorithm: %s", compression));
+            }
         }
+
+        if (cState.isLogged())
+            return new ReadyMessage();
+        else
+            return new AuthenticateMessage(DatabaseDescriptor.getAuthenticator().getClass().getName());
     }
 
     @Override