You are viewing a plain text version of this content. The canonical link for it is here (the hyperlink is not preserved in this plain-text version).
Posted to dev@tomcat.apache.org by ma...@apache.org on 2013/02/06 20:06:57 UTC

svn commit: r1443135 - in /tomcat/trunk: java/org/apache/tomcat/websocket/ java/org/apache/tomcat/websocket/server/ test/org/apache/tomcat/websocket/

Author: markt
Date: Wed Feb  6 19:06:56 2013
New Revision: 1443135

URL: http://svn.apache.org/viewvc?rev=1443135&view=rev
Log:
Refactor the RemoteEndpoint implementation.
- Add support for masking client data
- Add support for batching (a.k.a. buffering) messages
- Provide building blocks for Stream, Writer, etc. support

Modified:
    tomcat/trunk/java/org/apache/tomcat/websocket/WsRemoteEndpointBase.java
    tomcat/trunk/java/org/apache/tomcat/websocket/WsRemoteEndpointClient.java
    tomcat/trunk/java/org/apache/tomcat/websocket/WsSession.java
    tomcat/trunk/java/org/apache/tomcat/websocket/server/WsRemoteEndpointServer.java
    tomcat/trunk/test/org/apache/tomcat/websocket/TestWsWebSocketContainer.java

Modified: tomcat/trunk/java/org/apache/tomcat/websocket/WsRemoteEndpointBase.java
URL: http://svn.apache.org/viewvc/tomcat/trunk/java/org/apache/tomcat/websocket/WsRemoteEndpointBase.java?rev=1443135&r1=1443134&r2=1443135&view=diff
==============================================================================
--- tomcat/trunk/java/org/apache/tomcat/websocket/WsRemoteEndpointBase.java (original)
+++ tomcat/trunk/java/org/apache/tomcat/websocket/WsRemoteEndpointBase.java Wed Feb  6 19:06:56 2013
@@ -21,7 +21,6 @@ import java.io.OutputStream;
 import java.io.Writer;
 import java.nio.ByteBuffer;
 import java.nio.CharBuffer;
-import java.nio.channels.CompletionHandler;
 import java.nio.charset.Charset;
 import java.nio.charset.CharsetEncoder;
 import java.nio.charset.CoderResult;
@@ -31,6 +30,8 @@ import java.util.concurrent.Future;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.TimeoutException;
 import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.locks.Condition;
+import java.util.concurrent.locks.ReentrantLock;
 
 import javax.websocket.EncodeException;
 import javax.websocket.RemoteEndpoint;
@@ -44,17 +45,18 @@ public abstract class WsRemoteEndpointBa
     private static final StringManager sm =
             StringManager.getManager(Constants.PACKAGE_NAME);
 
-    // TODO Make the size of these buffers configurable
-    private final ByteBuffer intermediateBuffer = ByteBuffer.allocate(8192);
-    protected final ByteBuffer outputBuffer = ByteBuffer.allocate(8192);
-    private final AtomicBoolean charToByteInProgress = new AtomicBoolean(false);
-    private final CharsetEncoder encoder = Charset.forName("UTF8").newEncoder();
+    private final ReentrantLock writeLock = new ReentrantLock();
+    private final Condition notInProgress = writeLock.newCondition();
+    // Must hold writeLock above to modify state
     private final MessageSendStateMachine state = new MessageSendStateMachine();
-
+    // Max size of WebSocket header is 14 bytes
+    private final ByteBuffer headerBuffer = ByteBuffer.allocate(14);
+    private final ByteBuffer outputBuffer = ByteBuffer.allocate(8192);
+    private final CharsetEncoder encoder = Charset.forName("UTF8").newEncoder();
+    private final ByteBuffer encoderBuffer = ByteBuffer.allocate(8192);
+    private AtomicBoolean batchingAllowed = new AtomicBoolean(false);
     private volatile long asyncSendTimeout = -1;
 
-    protected ByteBuffer payload = null;
-
 
     @Override
     public long getAsyncSendTimeout() {
@@ -70,66 +72,79 @@ public abstract class WsRemoteEndpointBa
 
     @Override
     public void setBatchingAllowed(boolean batchingAllowed) {
-        // TODO Auto-generated method stub
+        boolean oldValue = this.batchingAllowed.getAndSet(batchingAllowed);
 
+        if (oldValue && !batchingAllowed) {
+            // Just disabled batched. Must flush.
+            flushBatch();
+        }
     }
 
 
     @Override
     public boolean getBatchingAllowed() {
-        // TODO Auto-generated method stub
-        return false;
+        return batchingAllowed.get();
     }
 
 
     @Override
     public void flushBatch() {
-        // TODO Auto-generated method stub
-
+        // Have to hold lock to flush output buffer
+        writeLock.lock();
+        try {
+            while (state.isInProgress()) {
+                notInProgress.await();
+            }
+            FutureToSendHandler f2sh = new FutureToSendHandler();
+            doWrite(f2sh, outputBuffer);
+            f2sh.get();
+        } catch (InterruptedException | ExecutionException e) {
+            // TODO Log this? Runtime exception? Something else?
+        } finally {
+            writeLock.unlock();
+        }
     }
 
 
     @Override
-    public final void sendString(String text) throws IOException {
-        sendPartialString(text, true);
+    public void sendBytes(ByteBuffer data) throws IOException {
+        Future<SendResult> f = sendBytesByFuture(data);
+        try {
+            SendResult sr = f.get();
+            if (!sr.isOK()) {
+                if (sr.getException() == null) {
+                    throw new IOException();
+                } else {
+                    throw new IOException(sr.getException());
+                }
+            }
+        } catch (InterruptedException | ExecutionException e) {
+            throw new IOException(e);
+        }
     }
 
 
     @Override
-    public final void sendBytes(ByteBuffer data) throws IOException {
-        sendPartialBytes(data, true);
+    public Future<SendResult> sendBytesByFuture(ByteBuffer data) {
+        FutureToSendHandler f2sh = new FutureToSendHandler();
+        sendBytesByCompletion(data, f2sh);
+        return f2sh;
     }
 
 
     @Override
-    public void sendPartialString(String fragment, boolean isLast)
-            throws IOException {
-
-        // The toBytes buffer needs to be protected from multiple threads and
-        // the state check happens to late.
-        if (!charToByteInProgress.compareAndSet(false, true)) {
-            throw new IllegalStateException(sm.getString(
-                    "wsRemoteEndpoint.concurrentMessageSend"));
+    public void sendBytesByCompletion(ByteBuffer data, SendHandler completion) {
+        boolean locked = writeLock.tryLock();
+        if (!locked) {
+            throw new IllegalStateException(
+                    sm.getString("wsRemoteEndpoint.concurrentMessageSend"));
         }
-
         try {
-            encoder.reset();
-            intermediateBuffer.clear();
-            CharBuffer cb = CharBuffer.wrap(fragment);
-            CoderResult cr = encoder.encode(cb, intermediateBuffer, true);
-            intermediateBuffer.flip();
-            while (cr.isOverflow()) {
-                sendMessageBlocking(
-                        Constants.OPCODE_TEXT, intermediateBuffer, false);
-                intermediateBuffer.clear();
-                cr = encoder.encode(cb, intermediateBuffer, true);
-                intermediateBuffer.flip();
-            }
-            sendMessageBlocking(
-                    Constants.OPCODE_TEXT, intermediateBuffer, isLast);
+            byte opCode = Constants.OPCODE_BINARY;
+            boolean isLast = true;
+            sendMessage(opCode, data, isLast, completion);
         } finally {
-            // Make sure flag is reset before method exists
-            charToByteInProgress.set(false);
+            writeLock.unlock();
         }
     }
 
@@ -137,130 +152,181 @@ public abstract class WsRemoteEndpointBa
     @Override
     public void sendPartialBytes(ByteBuffer partialByte, boolean isLast)
             throws IOException {
-        sendMessageBlocking(Constants.OPCODE_BINARY, partialByte, isLast);
+        boolean locked = writeLock.tryLock();
+        if (!locked) {
+            throw new IllegalStateException(
+                    sm.getString("wsRemoteEndpoint.concurrentMessageSend"));
+        }
+        try {
+            byte opCode = Constants.OPCODE_BINARY;
+            FutureToSendHandler f2sh = new FutureToSendHandler();
+            sendMessage(opCode, partialByte, isLast, f2sh);
+            f2sh.get();
+        } catch (InterruptedException | ExecutionException e) {
+            throw new IOException(e);
+        } finally {
+            writeLock.unlock();
+        }
     }
 
 
     @Override
-    public void sendPing(ByteBuffer applicationData) throws IOException {
-        sendMessageBlocking(Constants.OPCODE_PING, applicationData, true);
+    public void sendPing(ByteBuffer applicationData) throws IOException,
+            IllegalArgumentException {
+        sendControlMessage(Constants.OPCODE_PING, applicationData);
     }
 
 
     @Override
-    public void sendPong(ByteBuffer applicationData) throws IOException {
-        sendMessageBlocking(Constants.OPCODE_PONG, applicationData, true);
+    public void sendPong(ByteBuffer applicationData) throws IOException,
+            IllegalArgumentException {
+        sendControlMessage(Constants.OPCODE_PONG, applicationData);
     }
 
 
     @Override
-    public Future<SendResult> sendBytesByFuture(ByteBuffer data) {
-        this.payload = data;
-        return sendMessageByFuture(Constants.OPCODE_BINARY, true);
+    public void sendString(String text) throws IOException {
+        Future<SendResult> f = sendStringByFuture(text);
+        try {
+            SendResult sr = f.get();
+            if (!sr.isOK()) {
+                if (sr.getException() == null) {
+                    throw new IOException();
+                } else {
+                    throw new IOException(sr.getException());
+                }
+            }
+        } catch (InterruptedException | ExecutionException e) {
+            throw new IOException(e);
+        }
     }
 
 
     @Override
-    public void sendBytesByCompletion(ByteBuffer data, SendHandler completion) {
-        this.payload = data;
-        sendMessageByCompletion(Constants.OPCODE_BINARY, true,
-                new WsCompletionHandler(this, completion, state, false));
+    public Future<SendResult> sendStringByFuture(String text) {
+        FutureToSendHandler f2sh = new FutureToSendHandler();
+        sendStringByCompletion(text, f2sh);
+        return f2sh;
     }
 
 
+    @Override
+    public void sendStringByCompletion(String text, SendHandler completion) {
+        boolean locked = writeLock.tryLock();
+        if (!locked) {
+            throw new IllegalStateException(
+                    sm.getString("wsRemoteEndpoint.concurrentMessageSend"));
+        }
+        try {
+            TextMessageSendHandler tmsh = new TextMessageSendHandler(
+                    completion, text, true, encoder, encoderBuffer, this);
+            tmsh.write();
+        } finally {
+            writeLock.unlock();
+        }
+    }
 
 
-
-
-
-    protected void sendMessageBlocking(byte opCode, ByteBuffer payload,
-            boolean isLast) throws IOException {
-
-        this.payload = payload;
-
-        Future<SendResult> f = sendMessageByFuture(opCode, isLast);
-        SendResult sr = null;
+    @Override
+    public void sendPartialString(String fragment, boolean isLast)
+            throws IOException {
+        boolean locked = writeLock.tryLock();
+        if (!locked) {
+            throw new IllegalStateException(
+                    sm.getString("wsRemoteEndpoint.concurrentMessageSend"));
+        }
         try {
-            sr = f.get();
+            FutureToSendHandler f2sh = new FutureToSendHandler();
+            TextMessageSendHandler tmsh = new TextMessageSendHandler(
+                    f2sh, fragment, isLast, encoder, encoderBuffer, this);
+            tmsh.write();
+            f2sh.get();
         } catch (InterruptedException | ExecutionException e) {
             throw new IOException(e);
-        }
-
-        if (!sr.isOK()) {
-            throw new IOException(sr.getException());
+        } finally {
+            writeLock.unlock();
         }
     }
 
 
-    private Future<SendResult> sendMessageByFuture(byte opCode,
-            boolean isLast) {
 
-        WsCompletionHandler wsCompletionHandler = new WsCompletionHandler(
-                this, state, opCode == Constants.OPCODE_CLOSE);
-        sendMessageByCompletion(opCode, isLast, wsCompletionHandler);
-        return wsCompletionHandler;
-    }
+    /**
+     * Sends a control message, blocking until the message is sent.
+     */
+    void sendControlMessage(byte opCode, ByteBuffer payload)
+            throws IOException{
 
+        // Close needs to be sent so disable batching. This will flush any
+        // messages in the buffer
+        if (opCode == Constants.OPCODE_CLOSE) {
+            setBatchingAllowed(false);
+        }
 
-    private void sendMessageByCompletion(byte opCode, boolean isLast,
-            WsCompletionHandler handler) {
+        writeLock.lock();
+        try {
+            if (state.isInProgress()) {
+                notInProgress.await();
+            }
+            FutureToSendHandler f2sh = new FutureToSendHandler();
+            sendMessage(opCode, payload, true, f2sh);
+            f2sh.get();
+        } catch (InterruptedException | ExecutionException e) {
+            throw new IOException(e);
+        } finally {
+            notInProgress.signal();
+            writeLock.unlock();
+        }
+    }
 
-        boolean isFirst = state.startMessage(opCode, isLast);
 
-        outputBuffer.clear();
-        byte first = 0;
+    private void sendMessage(byte opCode, ByteBuffer payload, boolean last,
+            SendHandler completion) {
 
-        if (isLast) {
-            // Set the fin bit
-            first = -128;
+        if (!writeLock.isHeldByCurrentThread()) {
+            // Coding problem
+            throw new IllegalStateException(
+                    "Must hold writeLock before calling this method");
         }
 
-        if (isFirst) {
-            // This is the first fragment of this message
-            first = (byte) (first + opCode);
-        }
-        // If not the first fragment, it is a continuation with opCode of zero
+        state.startMessage(opCode, last);
 
-        outputBuffer.put(first);
+        SendMessageSendHandler smsh =
+                new SendMessageSendHandler(state, completion, this);
 
-        byte masked = getMasked();
+        byte[] mask;
 
-        // Next write the mask && length length
-        if (payload.limit() < 126) {
-            outputBuffer.put((byte) (payload.limit() | masked));
-        } else if (payload.limit() < 65536) {
-            outputBuffer.put((byte) (126 | masked));
-            outputBuffer.put((byte) (payload.limit() >>> 8));
-            outputBuffer.put((byte) (payload.limit() & 0xFF));
+        if (isMasked()) {
+            mask = Util.generateMask();
         } else {
-            // Will never be more than 2^31-1
-            outputBuffer.put((byte) (127 | masked));
-            outputBuffer.put((byte) 0);
-            outputBuffer.put((byte) 0);
-            outputBuffer.put((byte) 0);
-            outputBuffer.put((byte) 0);
-            outputBuffer.put((byte) (payload.limit() >>> 24));
-            outputBuffer.put((byte) (payload.limit() >>> 16));
-            outputBuffer.put((byte) (payload.limit() >>> 8));
-            outputBuffer.put((byte) (payload.limit() & 0xFF));
-        }
-        if (masked != 0) {
-            // TODO Mask the data properly
-            outputBuffer.put((byte) 0);
-            outputBuffer.put((byte) 0);
-            outputBuffer.put((byte) 0);
-            outputBuffer.put((byte) 0);
+            mask = null;
         }
-        outputBuffer.flip();
 
-        sendMessage(handler);
+        headerBuffer.clear();
+        writeHeader(headerBuffer, opCode, payload, state.isFirst(), last,
+                isMasked(), mask);
+        headerBuffer.flip();
+
+        if (getBatchingAllowed() || isMasked()) {
+            // Need to write via output buffer
+            OutputBufferSendHandler obsh = new OutputBufferSendHandler(
+                    smsh, headerBuffer, payload, mask, outputBuffer,
+                    !getBatchingAllowed(), this);
+            obsh.write();
+        } else {
+            // Can write directly
+            doWrite(smsh, headerBuffer, payload);
+        }
     }
 
-    protected abstract byte getMasked();
-
-    protected abstract void sendMessage(WsCompletionHandler handler);
 
-    protected abstract void close();
+    private void endMessage() {
+        writeLock.lock();
+        try {
+            notInProgress.signal();
+        } finally {
+            writeLock.unlock();
+        }
+    }
 
 
 
@@ -276,146 +342,88 @@ public abstract class WsRemoteEndpointBa
         return null;
     }
 
-
     @Override
     public Writer getSendWriter() throws IOException {
         // TODO Auto-generated method stub
         return null;
     }
 
-
-    @Override
-    public Future<SendResult> sendStringByFuture(String text) {
-        // TODO Auto-generated method stub
-        return null;
-    }
-
-
     @Override
     public void sendObject(Object o) throws IOException, EncodeException {
         // TODO Auto-generated method stub
-    }
-
 
-    @Override
-    public void sendStringByCompletion(String text, SendHandler completion) {
-        // TODO Auto-generated method stub
     }
 
-
     @Override
     public Future<SendResult> sendObjectByFuture(Object obj) {
         // TODO Auto-generated method stub
         return null;
     }
 
-
     @Override
     public void sendObjectByCompletion(Object obj, SendHandler completion) {
         // TODO Auto-generated method stub
-    }
-
-
-
-
-
-
-
-
-    protected static class WsCompletionHandler implements Future<SendResult>,
-            CompletionHandler<Long,Void> {
-
-        private final WsRemoteEndpointBase wsRemoteEndpoint;
-        private final MessageSendStateMachine state;
-        private final SendHandler sendHandler;
-        private final boolean close;
-        private final CountDownLatch latch = new CountDownLatch(1);
 
-        private volatile SendResult result = null;
-
-
-        public WsCompletionHandler(WsRemoteEndpointBase wsRemoteEndpoint,
-                MessageSendStateMachine state, boolean close) {
-            this(wsRemoteEndpoint, null, state, close);
-        }
-
-
-        public WsCompletionHandler(WsRemoteEndpointBase wsRemoteEndpoint,
-                SendHandler sendHandler, MessageSendStateMachine state,
-                boolean close) {
-            this.wsRemoteEndpoint = wsRemoteEndpoint;
-            this.sendHandler = sendHandler;
-            this.state = state;
-            this.close = close;
-        }
+    }
 
 
-        // ------------------------------------------- CompletionHandler methods
 
-        @Override
-        public void completed(Long result, Void attachment) {
-            state.endMessage();
-            if (close) {
-                wsRemoteEndpoint.close();
-            }
-            this.result = new SendResult();
-            latch.countDown();
-            if (sendHandler != null) {
-                sendHandler.setResult(this.result);
-            }
-        }
 
 
-        @Override
-        public void failed(Throwable exc, Void attachment) {
-            state.endMessage();
-            if (close) {
-                wsRemoteEndpoint.close();
-            }
-            this.result = new SendResult(exc);
-            latch.countDown();
-            if (sendHandler != null) {
-                sendHandler.setResult(this.result);
-            }
-        }
+    protected abstract void doWrite(SendHandler handler, ByteBuffer... data);
+    protected abstract boolean isMasked();
+    protected abstract void close();
 
 
-        // ------------------------------------------------------ Future methods
+    private static void writeHeader(ByteBuffer headerBuffer, byte opCode,
+            ByteBuffer payload, boolean first, boolean last, boolean masked,
+            byte[] mask) {
 
-        @Override
-        public boolean cancel(boolean mayInterruptIfRunning) {
-            // Cancelling the task is not supported
-            return false;
-        }
+        byte b = 0;
 
-
-        @Override
-        public boolean isCancelled() {
-            // Cancelling the task is not supported
-            return false;
+        if (last) {
+            // Set the fin bit
+            b = -128;
         }
 
-
-        @Override
-        public boolean isDone() {
-            return latch.getCount() == 0;
+        if (first) {
+            // This is the first fragment of this message
+            b = (byte) (b + opCode);
         }
+        // If not the first fragment, it is a continuation with opCode of zero
 
+        headerBuffer.put(b);
 
-        @Override
-        public SendResult get() throws InterruptedException, ExecutionException {
-            latch.await();
-            return result;
+        if (masked) {
+            b = (byte) 0x80;
+        } else {
+            b = 0;
         }
 
-
-        @Override
-        public SendResult get(long timeout, TimeUnit unit)
-                throws InterruptedException, ExecutionException,
-                TimeoutException {
-
-            latch.await(timeout, unit);
-            return result;
+        // Next write the mask && length length
+        if (payload.limit() < 126) {
+            headerBuffer.put((byte) (payload.limit() | b));
+        } else if (payload.limit() < 65536) {
+            headerBuffer.put((byte) (126 | b));
+            headerBuffer.put((byte) (payload.limit() >>> 8));
+            headerBuffer.put((byte) (payload.limit() & 0xFF));
+        } else {
+            // Will never be more than 2^31-1
+            headerBuffer.put((byte) (127 | b));
+            headerBuffer.put((byte) 0);
+            headerBuffer.put((byte) 0);
+            headerBuffer.put((byte) 0);
+            headerBuffer.put((byte) 0);
+            headerBuffer.put((byte) (payload.limit() >>> 24));
+            headerBuffer.put((byte) (payload.limit() >>> 16));
+            headerBuffer.put((byte) (payload.limit() >>> 8));
+            headerBuffer.put((byte) (payload.limit() & 0xFF));
+        }
+        if (masked) {
+            headerBuffer.put(mask[0]);
+            headerBuffer.put(mask[1]);
+            headerBuffer.put(mask[2]);
+            headerBuffer.put(mask[3]);
         }
     }
 
@@ -425,11 +433,12 @@ public abstract class WsRemoteEndpointBa
         private boolean inProgress = false;
         private boolean fragmented = false;
         private boolean text = false;
+        private boolean first = false;
 
         private boolean nextFragmented = false;
         private boolean nextText = false;
 
-        public synchronized boolean startMessage(byte opCode, boolean isLast) {
+        public synchronized void startMessage(byte opCode, boolean isLast) {
 
             if (closed) {
                 throw new IllegalStateException(
@@ -451,7 +460,8 @@ public abstract class WsRemoteEndpointBa
                 if (opCode == Constants.OPCODE_CLOSE) {
                     closed = true;
                 }
-                return true;
+                first = true;
+                return;
             }
 
             boolean isText = Util.isText(opCode);
@@ -464,7 +474,7 @@ public abstract class WsRemoteEndpointBa
                 }
                 nextText = text;
                 nextFragmented = !isLast;
-                return false;
+                first = false;
             } else {
                 // Wasn't fragmented. Might be now
                 if (isLast) {
@@ -473,7 +483,7 @@ public abstract class WsRemoteEndpointBa
                     nextFragmented = true;
                     nextText = isText;
                 }
-                return true;
+                first = true;
             }
         }
 
@@ -482,5 +492,219 @@ public abstract class WsRemoteEndpointBa
             fragmented = nextFragmented;
             text = nextText;
         }
+
+        public synchronized boolean isInProgress() {
+            return inProgress;
+        }
+
+        public synchronized boolean isFirst() {
+            return first;
+        }
+    }
+
+
+    private static class TextMessageSendHandler implements SendHandler {
+
+        private final SendHandler handler;
+        private final CharBuffer message;
+        private final boolean isLast;
+        private final CharsetEncoder encoder;
+        private final ByteBuffer buffer;
+        private final WsRemoteEndpointBase endpoint;
+        private volatile boolean isDone = false;
+
+        public TextMessageSendHandler(SendHandler handler, String message,
+                boolean isLast, CharsetEncoder encoder,
+                ByteBuffer encoderBuffer, WsRemoteEndpointBase endpoint) {
+            this.handler = handler;
+            this.message = CharBuffer.wrap(message);
+            this.isLast = isLast;
+            this.encoder = encoder.reset();
+            this.buffer = encoderBuffer;
+            this.endpoint = endpoint;
+        }
+
+        public void write() {
+            buffer.clear();
+            CoderResult cr = encoder.encode(message, buffer, true);
+            if (cr.isError()) {
+                throw new IllegalArgumentException(cr.toString());
+            }
+            isDone = !cr.isOverflow();
+            buffer.flip();
+            endpoint.sendMessage(Constants.OPCODE_TEXT, buffer,
+                    isDone && isLast, this);
+        }
+
+        @Override
+        public void setResult(SendResult result) {
+            if (isDone || !result.isOK()) {
+                handler.setResult(result);
+            } else {
+                write();
+            }
+        }
+    }
+
+
+    /**
+     *  Wraps user provided {@link SendHandler} so that state is updated when
+     *  the message completes.
+     */
+    private static class SendMessageSendHandler implements SendHandler {
+
+        private final MessageSendStateMachine state;
+        private final SendHandler handler;
+        private final WsRemoteEndpointBase endpoint;
+
+        public SendMessageSendHandler(MessageSendStateMachine state,
+                SendHandler handler, WsRemoteEndpointBase endpoint) {
+            this.state = state;
+            this.handler = handler;
+            this.endpoint = endpoint;
+        }
+
+        @Override
+        public void setResult(SendResult result) {
+            state.endMessage();
+            if (state.closed) {
+                endpoint.close();
+            }
+            handler.setResult(result);
+            endpoint.endMessage();
+        }
+    }
+
+
+    /**
+     * Used to write data to the output buffer, flushing the buffer if it fills
+     * up.
+     */
+    private static class OutputBufferSendHandler implements SendHandler {
+
+        private final SendHandler handler;
+        private final ByteBuffer headerBuffer;
+        private final ByteBuffer payload;
+        private final byte[] mask;
+        private final ByteBuffer outputBuffer;
+        private volatile boolean flushRequired;
+        private final WsRemoteEndpointBase endpoint;
+        private int maskIndex = 0;
+
+        public OutputBufferSendHandler(SendHandler completion,
+                ByteBuffer headerBuffer, ByteBuffer payload, byte[] mask,
+                ByteBuffer outputBuffer, boolean flushRequired,
+                WsRemoteEndpointBase endpoint) {
+            this.handler = completion;
+            this.headerBuffer = headerBuffer;
+            this.payload = payload;
+            this.mask = mask;
+            this.outputBuffer = outputBuffer;
+            this.flushRequired = flushRequired;
+            this.endpoint = endpoint;
+        }
+
+        public void write() {
+            // Write the header
+            while (headerBuffer.hasRemaining() && outputBuffer.hasRemaining()) {
+                outputBuffer.put(headerBuffer.get());
+            }
+            if (headerBuffer.hasRemaining()) {
+                // Still more headers to write, need to flush
+                flushRequired = true;
+                outputBuffer.flip();
+                endpoint.doWrite(this, outputBuffer);
+                return;
+            }
+
+            // Write the payload
+            while (payload.hasRemaining() && outputBuffer.hasRemaining()) {
+                outputBuffer.put(
+                        (byte) (payload.get() ^ (mask[maskIndex++] & 0xFF)));
+                if (maskIndex > 3) {
+                    maskIndex = 0;
+                }
+            }
+            if (payload.hasRemaining()) {
+                // Still more headers to write, need to flush
+                flushRequired = true;
+                outputBuffer.flip();
+                endpoint.doWrite(this, outputBuffer);
+                return;
+            }
+
+            if (flushRequired) {
+                outputBuffer.flip();
+                endpoint.doWrite(this, outputBuffer);
+                flushRequired = false;
+                return;
+            } else {
+                handler.setResult(new SendResult());
+            }
+        }
+
+        // ------------------------------------------------- SendHandler methods
+        @Override
+        public void setResult(SendResult result) {
+            outputBuffer.clear();
+            if (result.isOK()) {
+                write();
+            } else {
+                handler.setResult(result);
+            }
+        }
+    }
+
+    /**
+     * Converts a Future to a SendHandler.
+     */
+    private static class FutureToSendHandler
+            implements Future<SendResult>, SendHandler {
+
+        private final CountDownLatch latch = new CountDownLatch(1);
+        private volatile SendResult result = null;
+
+        // --------------------------------------------------------- SendHandler
+
+        @Override
+        public void setResult(SendResult result) {
+            this.result = result;
+            latch.countDown();
+        }
+
+
+        // -------------------------------------------------------------- Future
+
+        @Override
+        public boolean cancel(boolean mayInterruptIfRunning) {
+            // Cancelling the task is not supported
+            return false;
+        }
+
+        @Override
+        public boolean isCancelled() {
+            // Cancelling the task is not supported
+            return false;
+        }
+
+        @Override
+        public boolean isDone() {
+            return latch.getCount() == 0;
+        }
+
+        @Override
+        public SendResult get() throws InterruptedException,
+        ExecutionException {
+            latch.await();
+            return result;
+        }
+
+        @Override
+        public SendResult get(long timeout, TimeUnit unit)
+                throws InterruptedException, ExecutionException,
+                TimeoutException {
+            latch.await(timeout, unit);
+            return result;
+        }
     }
 }

Modified: tomcat/trunk/java/org/apache/tomcat/websocket/WsRemoteEndpointClient.java
URL: http://svn.apache.org/viewvc/tomcat/trunk/java/org/apache/tomcat/websocket/WsRemoteEndpointClient.java?rev=1443135&r1=1443134&r2=1443135&view=diff
==============================================================================
--- tomcat/trunk/java/org/apache/tomcat/websocket/WsRemoteEndpointClient.java (original)
+++ tomcat/trunk/java/org/apache/tomcat/websocket/WsRemoteEndpointClient.java Wed Feb  6 19:06:56 2013
@@ -19,8 +19,12 @@ package org.apache.tomcat.websocket;
 import java.io.IOException;
 import java.nio.ByteBuffer;
 import java.nio.channels.AsynchronousSocketChannel;
+import java.nio.channels.CompletionHandler;
 import java.util.concurrent.TimeUnit;
 
+import javax.websocket.SendHandler;
+import javax.websocket.SendResult;
+
 public class WsRemoteEndpointClient extends WsRemoteEndpointBase {
 
     private final AsynchronousSocketChannel channel;
@@ -31,20 +35,22 @@ public class WsRemoteEndpointClient exte
 
 
     @Override
-    protected byte getMasked() {
-        return (byte) 0x80;
+    protected boolean isMasked() {
+        return true; // data sent by the client is always masked
     }
 
 
     @Override
-    protected void sendMessage(WsCompletionHandler handler) {
+    protected void doWrite(SendHandler handler, ByteBuffer... data) {
         long timeout = getAsyncSendTimeout();
         if (timeout < 1) {
             timeout = Long.MAX_VALUE;
 
         }
-        channel.write(new ByteBuffer[] {outputBuffer, payload}, 0, 2,
-                getAsyncSendTimeout(), TimeUnit.MILLISECONDS, null, handler);
+        SendHandlerToCompletionHandler sh2ch =
+                new SendHandlerToCompletionHandler(handler);
+        channel.write(data, 0, data.length, timeout, // <1 => unlimited
+                TimeUnit.MILLISECONDS, null, sh2ch);
     }
 
     @Override
@@ -55,4 +61,25 @@ public class WsRemoteEndpointClient exte
             // Ignore
         }
     }
+
+    /** Adapts a {@link SendHandler} to an NIO {@link CompletionHandler}. */
+    private static class SendHandlerToCompletionHandler
+            implements CompletionHandler<Long,Void> {
+
+        private final SendHandler handler;
+
+        public SendHandlerToCompletionHandler(SendHandler handler) {
+            this.handler = handler;
+        }
+
+        @Override
+        public void completed(Long result, Void attachment) {
+            handler.setResult(new SendResult());
+        }
+
+        @Override
+        public void failed(Throwable exc, Void attachment) {
+            handler.setResult(new SendResult(exc));
+        }
+    }
 }

Modified: tomcat/trunk/java/org/apache/tomcat/websocket/WsSession.java
URL: http://svn.apache.org/viewvc/tomcat/trunk/java/org/apache/tomcat/websocket/WsSession.java?rev=1443135&r1=1443134&r2=1443135&view=diff
==============================================================================
--- tomcat/trunk/java/org/apache/tomcat/websocket/WsSession.java (original)
+++ tomcat/trunk/java/org/apache/tomcat/websocket/WsSession.java Wed Feb  6 19:06:56 2013
@@ -254,8 +254,8 @@ public class WsSession implements Sessio
             }
             msg.flip();
             try {
-                wsRemoteEndpoint.sendMessageBlocking(
-                        Constants.OPCODE_CLOSE, msg, true);
+                wsRemoteEndpoint.sendControlMessage(
+                        Constants.OPCODE_CLOSE, msg);
             } catch (IOException ioe) {
                 // Unable to send close message.
                 // TODO - Ignore?

Modified: tomcat/trunk/java/org/apache/tomcat/websocket/server/WsRemoteEndpointServer.java
URL: http://svn.apache.org/viewvc/tomcat/trunk/java/org/apache/tomcat/websocket/server/WsRemoteEndpointServer.java?rev=1443135&r1=1443134&r2=1443135&view=diff
==============================================================================
--- tomcat/trunk/java/org/apache/tomcat/websocket/server/WsRemoteEndpointServer.java (original)
+++ tomcat/trunk/java/org/apache/tomcat/websocket/server/WsRemoteEndpointServer.java Wed Feb  6 19:06:56 2013
@@ -18,8 +18,11 @@ package org.apache.tomcat.websocket.serv
 
 import java.io.IOException;
 import java.net.SocketTimeoutException;
+import java.nio.ByteBuffer;
 
 import javax.servlet.ServletOutputStream;
+import javax.websocket.SendHandler;
+import javax.websocket.SendResult;
 
 import org.apache.juli.logging.Log;
 import org.apache.juli.logging.LogFactory;
@@ -40,12 +43,11 @@ public class WsRemoteEndpointServer exte
 
     private final ServletOutputStream sos;
     private final WsTimeout wsTimeout;
-    private volatile WsCompletionHandler handler = null;
+    private volatile SendHandler handler = null; // cleared when write ends
+    private volatile ByteBuffer[] buffers = null; // pending write data
+
     private volatile long timeoutExpiry = -1;
     private volatile boolean close;
-    private volatile Long size = null;
-    private volatile boolean headerWritten = false;
-    private volatile boolean payloadWritten = false;
 
 
     public WsRemoteEndpointServer(ServletOutputStream sos,
@@ -56,50 +58,59 @@ public class WsRemoteEndpointServer exte
 
 
     @Override
-    protected byte getMasked() {
-        // Messages from the server are not masked
-        return 0;
+    protected final boolean isMasked() {
+        return false; // messages from the server are not masked
     }
 
 
     @Override
-    protected void sendMessage(WsCompletionHandler handler) {
+    protected void doWrite(SendHandler handler, ByteBuffer... buffers) {
         this.handler = handler;
+        this.buffers = buffers; // consumed by onWritePossible()
         onWritePossible();
     }
 
 
     public void onWritePossible() {
+        boolean complete = true;
         try {
             // If this is false there will be a call back when it is true
             while (sos.canWrite()) {
-                if (!headerWritten) {
-                    headerWritten = true;
-                    size = Long.valueOf(
-                            outputBuffer.remaining() + payload.remaining());
-                    sos.write(outputBuffer.array(), outputBuffer.arrayOffset(),
-                            outputBuffer.limit());
-                } else if (!payloadWritten) {
-                    payloadWritten = true;
-                    sos.write(payload.array(), payload.arrayOffset(),
-                            payload.limit());
-                } else {
+                complete = true;
+                for (ByteBuffer buffer : buffers) {
+                    if (buffer.hasRemaining()) {
+                        complete = false;
+                        sos.write(buffer.array(), buffer.arrayOffset(),
+                                buffer.limit());
+                        buffer.position(buffer.limit());
+                        break;
+                    }
+                }
+                if (complete) {
                     wsTimeout.unregister(this);
                     if (close) {
                         close();
                     }
-                    handler.completed(size, null);
-                    nextWrite();
+                    // Setting the result marks this (partial) message as
+                    // complete which means the next one may be sent which
+                    // could update the value of the handler. Therefore, keep a
+                    // local copy before signalling the end of the (partial)
+                    // message.
+                    SendHandler sh = handler;
+                    handler = null;
+                    sh.setResult(new SendResult());
                     break;
                 }
             }
+
         } catch (IOException ioe) {
             wsTimeout.unregister(this);
             close();
-            handler.failed(ioe, null);
-            nextWrite();
+            SendHandler sh = handler;
+            handler = null;
+            sh.setResult(new SendResult(ioe));
         }
-        if (handler != null) {
+        if (!complete) {
             // Async write is in progress
 
             long timeout = getAsyncSendTimeout();
@@ -132,15 +143,9 @@ public class WsRemoteEndpointServer exte
 
     protected void onTimeout() {
         close();
-        handler.failed(new SocketTimeoutException(), null);
-        nextWrite();
-    }
-
-
-    private void nextWrite() {
+        // Clear before signalling: setResult() may start the next write
+        SendHandler sh = handler;
         handler = null;
+        sh.setResult(new SendResult(new SocketTimeoutException()));
-        size = null;
-        headerWritten = false;
-        payloadWritten = false;
     }
 }

Modified: tomcat/trunk/test/org/apache/tomcat/websocket/TestWsWebSocketContainer.java
URL: http://svn.apache.org/viewvc/tomcat/trunk/test/org/apache/tomcat/websocket/TestWsWebSocketContainer.java?rev=1443135&r1=1443134&r2=1443135&view=diff
==============================================================================
--- tomcat/trunk/test/org/apache/tomcat/websocket/TestWsWebSocketContainer.java (original)
+++ tomcat/trunk/test/org/apache/tomcat/websocket/TestWsWebSocketContainer.java Wed Feb  6 19:06:56 2013
@@ -160,7 +160,7 @@ public class TestWsWebSocketContainer ex
 
     @Test
     public void testSmallBinaryBufferClientTextMessage() throws Exception {
-        doBufferTest(false, false, true, false);
+        doBufferTest(false, false, true, true); // TODO: label boolean args
     }
 
 
@@ -172,7 +172,7 @@ public class TestWsWebSocketContainer ex
 
     @Test
     public void testSmallBinaryBufferServerTextMessage() throws Exception {
-        doBufferTest(false, true, true, false);
+        doBufferTest(false, true, true, true); // TODO: label boolean args
     }
 
 
@@ -382,7 +382,6 @@ public class TestWsWebSocketContainer ex
         // Check nothing really bad happened
         Assert.assertNull(ConstantTxEndpoint.getException());
 
-        System.out.println(ConstantTxEndpoint.getTimeout());
         // Check correct time passed
         Assert.assertTrue(ConstantTxEndpoint.getTimeout() >= TIMEOUT_MS);
 



---------------------------------------------------------------------
To unsubscribe, e-mail: dev-unsubscribe@tomcat.apache.org
For additional commands, e-mail: dev-help@tomcat.apache.org