You are viewing a plain-text version of this content. The canonical link for it is here (the hyperlink itself was not preserved in this plain-text copy).
Posted to dev@tomcat.apache.org by ma...@apache.org on 2012/12/19 21:02:04 UTC

svn commit: r1424066 - /tomcat/trunk/java/org/apache/tomcat/websocket/WsFrame.java

Author: markt
Date: Wed Dec 19 20:02:04 2012
New Revision: 1424066

URL: http://svn.apache.org/viewvc?rev=1424066&view=rev
Log:
WebSocket 1.0 implementation part 17 of many
Improve the handling of fragmented messages

Modified:
    tomcat/trunk/java/org/apache/tomcat/websocket/WsFrame.java

Modified: tomcat/trunk/java/org/apache/tomcat/websocket/WsFrame.java
URL: http://svn.apache.org/viewvc/tomcat/trunk/java/org/apache/tomcat/websocket/WsFrame.java?rev=1424066&r1=1424065&r2=1424066&view=diff
==============================================================================
--- tomcat/trunk/java/org/apache/tomcat/websocket/WsFrame.java (original)
+++ tomcat/trunk/java/org/apache/tomcat/websocket/WsFrame.java Wed Dec 19 20:02:04 2012
@@ -19,7 +19,6 @@ package org.apache.tomcat.websocket;
 import java.io.EOFException;
 import java.io.IOException;
 import java.nio.ByteBuffer;
-import java.nio.charset.Charset;
 
 import javax.servlet.ServletInputStream;
 import javax.websocket.MessageHandler;
@@ -33,28 +32,40 @@ import org.apache.tomcat.util.res.String
  */
 public class WsFrame {
 
-    private static StringManager sm = StringManager.getManager(Constants.PACKAGE_NAME);
+    private static StringManager sm =
+            StringManager.getManager(Constants.PACKAGE_NAME);
+
+    // Connection level attributes
     private final ServletInputStream sis;
     private final WsSession wsSession;
     private final byte[] inputBuffer;
-    private int pos = 0;
-    private State state = State.NEW_FRAME;
-    private int headerLength = 0;
-    private boolean continutationExpected = false;
+
+    // Attributes of the current message
+    private final ByteBuffer messageBuffer;
+    private boolean continuationExpected = false;
     private boolean textMessage = false;
-    private long payloadSent = 0;
-    private long payloadLength = 0;
-    private boolean fin;
-    private int rsv;
-    private byte opCode;
+
+    // Attributes of the current frame
+    private boolean fin = false;
+    private int rsv = 0;
+    private byte opCode = 0;
+    private int frameStart = 0;
+    private int headerLength = 0;
     private byte[] mask = new byte[4];
-    int maskIndex = 0;
+    private int maskIndex = 0;
+    private long payloadLength = 0;
+    private int payloadRead = 0;
+    private long payloadWritten = 0;
 
+    // Attributes tracking state
+    private State state = State.NEW_FRAME;
+    private int writePos = 0;
 
     public WsFrame(ServletInputStream sis, WsSession wsSession) {
         this.sis = sis;
         this.wsSession = wsSession;
         inputBuffer = new byte[8192];
+        messageBuffer = ByteBuffer.allocate(8192);
     }
 
 
@@ -64,14 +75,15 @@ public class WsFrame {
     public void onDataAvailable() throws IOException {
         while (sis.isReady()) {
             // Fill up the input buffer with as much data as we can
-            int read = sis.read(inputBuffer, pos, inputBuffer.length - pos);
+            int read = sis.read(inputBuffer, writePos,
+                    inputBuffer.length - writePos);
             if (read == 0) {
                 return;
             }
             if (read == -1) {
                 throw new EOFException();
             }
-            pos += read;
+            writePos += read;
             while (true) {
                 if (state == State.NEW_FRAME) {
                     if (!processInitialHeader()) {
@@ -99,15 +111,15 @@ public class WsFrame {
      */
     private boolean processInitialHeader() throws IOException {
         // Need at least two bytes of data to do this
-        if (pos < 2) {
+        if (writePos - frameStart < 2) {
             return false;
         }
-        int b = inputBuffer[0];
+        int b = inputBuffer[frameStart];
         fin = (b & 0x80) > 0;
         rsv = (b & 0x70) >>> 4;
         opCode = (byte) (b & 0x0F);
         if (!isControl()) {
-            if (continutationExpected) {
+            if (continuationExpected) {
                 if (opCode != Constants.OPCODE_CONTINUATION) {
                     // TODO i18n
                     throw new IllegalStateException();
@@ -122,9 +134,9 @@ public class WsFrame {
                     throw new UnsupportedOperationException();
                 }
             }
-            continutationExpected = !fin;
+            continuationExpected = !fin;
         }
-        b = inputBuffer[1];
+        b = inputBuffer[frameStart + 1];
         // Client data must be masked
         if ((b & 0x80) == 0) {
             throw new IOException(sm.getString("wsFrame.notMasked"));
@@ -148,7 +160,7 @@ public class WsFrame {
         } else if (payloadLength == 127) {
             headerLength += 8;
         }
-        if (pos < headerLength) {
+        if (writePos - frameStart < headerLength) {
             return false;
         }
         // Calculate new payload length if necessary
@@ -167,57 +179,69 @@ public class WsFrame {
                 throw new IOException("wsFrame.controlNoFin");
             }
         }
-        System.arraycopy(inputBuffer, headerLength - 4, mask, 0, 4);
+        System.arraycopy(inputBuffer, frameStart + headerLength - 4, mask, 0, 4);
         state = State.DATA;
+        payloadRead = frameStart + headerLength;
         return true;
     }
 
 
     private boolean processData() throws IOException {
+        checkRoomPayload();
+        appendPayloadToMessage();
         if (isControl()) {
-            if (!isPayloadComplete()) {
+            if (writePos < frameStart + headerLength + payloadLength) {
                 return false;
             }
             if (opCode == Constants.OPCODE_CLOSE) {
                 wsSession.close();
             } else if (opCode == Constants.OPCODE_PING) {
-                wsSession.getRemote().sendPong(getPayloadBinary());
+                messageBuffer.flip();
+                wsSession.getRemote().sendPong(messageBuffer);
             } else if (opCode == Constants.OPCODE_PONG) {
                 MessageHandler.Basic<PongMessage> mhPong = wsSession.getPongMessageHandler();
                 if (mhPong != null) {
-                    mhPong.onMessage(new WsPongMessage(getPayloadBinary()));
+                    messageBuffer.flip();
+                    mhPong.onMessage(new WsPongMessage(messageBuffer));
                 }
             } else {
                 // TODO i18n
                 throw new UnsupportedOperationException();
             }
+            newMessage();
             return true;
         }
-        if (!isPayloadComplete()) {
-            if (usePartial()) {
-                sendPayload(false);
-                return false;
-            } else {
-                if (inputBuffer.length - pos > 0) {
-                    return false;
+        if (payloadWritten == payloadLength) {
+            if (continuationExpected) {
+                if (usePartial()) {
+                    messageBuffer.flip();
+                    sendMessage(false);
+                    messageBuffer.clear();
                 }
-                throw new UnsupportedOperationException();
+                newFrame();
+                return true;
+            } else {
+                messageBuffer.flip();
+                sendMessage(true);
+                newMessage();
+                return true;
             }
         } else {
-            sendPayload(true);
+            if (usePartial()) {
+                messageBuffer.flip();
+                sendMessage(false);
+                messageBuffer.clear();
+            }
+            return false;
         }
-        state = State.NEW_FRAME;
-        payloadLength = 0;
-        payloadSent = 0;
-        maskIndex = 0;
-        return true;
     }
 
 
     @SuppressWarnings("unchecked")
-    private void sendPayload(boolean last) {
+    private void sendMessage(boolean last) {
         if (textMessage) {
-            String payload = getPayloadText();
+            String payload =
+                    new String(messageBuffer.array(), 0, messageBuffer.limit());
             MessageHandler mh = wsSession.getTextMessageHandler();
             if (mh != null) {
                 if (mh instanceof MessageHandler.Async<?>) {
@@ -227,27 +251,83 @@ public class WsFrame {
                 }
             }
         } else {
-            ByteBuffer payload = getPayloadBinary();
             MessageHandler mh = wsSession.getBinaryMessageHandler();
             if (mh != null) {
                 if (mh instanceof MessageHandler.Async<?>) {
-                    ((MessageHandler.Async<ByteBuffer>) mh).onMessage(payload,
-                            last);
+                    ((MessageHandler.Async<ByteBuffer>) mh).onMessage(
+                            messageBuffer, last);
                 } else {
-                    ((MessageHandler.Basic<ByteBuffer>) mh).onMessage(payload);
+                    ((MessageHandler.Basic<ByteBuffer>) mh).onMessage(
+                            messageBuffer);
                 }
             }
         }
     }
 
 
-    private boolean isControl() {
-        return (opCode & 0x08) > 0;
+    private void newMessage() {
+        messageBuffer.clear();
+        continuationExpected = false;
+        newFrame();
+    }
+
+
+    private void newFrame() {
+        if (frameStart + headerLength + payloadLength == writePos) {
+            frameStart = 0;
+            writePos = 0;
+        } else {
+            frameStart = frameStart + headerLength + (int) payloadLength;
+        }
+
+        // These get reset in processInitialHeader()
+        // fin, rsv, opCode, headerLength, payloadLength, mask
+        maskIndex = 0;
+        payloadRead = 0;
+        payloadWritten = 0;
+        state = State.NEW_FRAME;
+        checkRoomHeaders();
+    }
+
+
+    private void checkRoomHeaders() {
+        // Is the start of the current frame too near the end of the input
+        // buffer?
+        if (inputBuffer.length - frameStart < 131) {
+            // Limit based on a control frame with a full payload
+            makeRoom();
+        }
+    }
+
+
+    private void checkRoomPayload() throws IOException {
+        long frameSize = headerLength + payloadLength;
+        if (inputBuffer.length - frameStart - frameSize < 0) {
+            if (isControl()) {
+                makeRoom();
+                return;
+            }
+            // Might not be enough room
+            if (usePartial()) {
+                // Not a problem - can use partial messages
+                return;
+            }
+            if (inputBuffer.length < frameSize) {
+                // Never going to work
+                // TODO i18n - buffer too small
+                throw new IOException();
+            }
+            makeRoom();
+        }
     }
 
 
-    private boolean isPayloadComplete() {
-        return (payloadSent + pos - headerLength) >= payloadLength;
+    private void makeRoom() {
+        System.arraycopy(inputBuffer, frameStart, inputBuffer, 0,
+                writePos - frameStart);
+        writePos = writePos - frameStart;
+        payloadRead = payloadRead - frameStart;
+        frameStart = 0;
     }
 
 
@@ -271,32 +351,17 @@ public class WsFrame {
     }
 
 
-    private ByteBuffer getPayloadBinary() {
-        int end;
-        if (isPayloadComplete()) {
-            end = (int) (payloadLength - payloadSent) + headerLength;
-        } else {
-            end = pos;
-        }
-        ByteBuffer result = ByteBuffer.allocate(end - headerLength);
-        for (int i = headerLength; i < end; i++) {
-            result.put(i - headerLength,
-                    (byte) ((inputBuffer[i] ^ mask[maskIndex]) & 0xFF));
+    private void appendPayloadToMessage() {
+        while (payloadWritten < payloadLength && payloadRead < writePos) {
+            byte b = (byte) ((inputBuffer[payloadRead] ^ mask[maskIndex]) & 0xFF);
             maskIndex++;
             if (maskIndex == 4) {
                 maskIndex = 0;
             }
+            payloadRead++;
+            payloadWritten++;
+            messageBuffer.put(b);
         }
-        // May have read past end of current frame into next
-        pos = 0;
-        headerLength = 0;
-        return result;
-    }
-
-
-    private String getPayloadText() {
-        ByteBuffer bb = getPayloadBinary();
-        return new String(bb.array(), Charset.forName("UTF-8"));
     }
 
 
@@ -315,6 +380,12 @@ public class WsFrame {
         return result;
     }
 
+
+    private boolean isControl() {
+        return (opCode & 0x08) > 0;
+    }
+
+
     private static enum State {
         NEW_FRAME, PARTIAL_HEADER, DATA
     }



---------------------------------------------------------------------
To unsubscribe, e-mail: dev-unsubscribe@tomcat.apache.org
For additional commands, e-mail: dev-help@tomcat.apache.org