You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mina.apache.org by el...@apache.org on 2009/12/10 01:04:19 UTC

svn commit: r889025 - /mina/trunk/core/src/main/java/org/apache/mina/filter/ssl/SslHandler.java

Author: elecharny
Date: Thu Dec 10 00:04:19 2009
New Revision: 889025

URL: http://svn.apache.org/viewvc?rev=889025&view=rev
Log:
o Fix for DIRMINA-650 proposed by Cedric Lucas applied (with a slight cosmetic modification)
o Some other simplifications

Modified:
    mina/trunk/core/src/main/java/org/apache/mina/filter/ssl/SslHandler.java

Modified: mina/trunk/core/src/main/java/org/apache/mina/filter/ssl/SslHandler.java
URL: http://svn.apache.org/viewvc/mina/trunk/core/src/main/java/org/apache/mina/filter/ssl/SslHandler.java?rev=889025&r1=889024&r2=889025&view=diff
==============================================================================
--- mina/trunk/core/src/main/java/org/apache/mina/filter/ssl/SslHandler.java (original)
+++ mina/trunk/core/src/main/java/org/apache/mina/filter/ssl/SslHandler.java Thu Dec 10 00:04:19 2009
@@ -29,6 +29,8 @@
 import javax.net.ssl.SSLEngineResult;
 import javax.net.ssl.SSLException;
 import javax.net.ssl.SSLHandshakeException;
+import javax.net.ssl.SSLEngineResult.HandshakeStatus;
+import javax.net.ssl.SSLEngineResult.Status;
 
 import org.apache.mina.core.buffer.IoBuffer;
 import org.apache.mina.core.filterchain.IoFilterEvent;
@@ -89,14 +91,16 @@
     private boolean handshakeComplete;
     private boolean writingEncryptedData;
 
+    /** Flag value passed to unwrap0() when the handshake is finished */
+    private static final boolean HANDSHAKE_FINISHED = true;
+
     /**
      * Constuctor.
      *
      * @param sslc
      * @throws SSLException
      */
-    public SslHandler(SslFilter parent, SSLContext sslContext, IoSession session)
-            throws SSLException {
+    public SslHandler(SslFilter parent, SSLContext sslContext, IoSession session) throws SSLException {
         this.parent = parent;
         this.session = session;
         this.sslContext = sslContext;
@@ -114,16 +118,15 @@
             return;
         }
 
-        InetSocketAddress peer = (InetSocketAddress) session
-                .getAttribute(SslFilter.PEER_ADDRESS);
-        
+        InetSocketAddress peer = (InetSocketAddress) session.getAttribute(SslFilter.PEER_ADDRESS);
+
         // Create the SSL engine here
         if (peer == null) {
             sslEngine = sslContext.createSSLEngine();
         } else {
             sslEngine = sslContext.createSSLEngine(peer.getHostName(), peer.getPort());
         }
-        
+
         // Initialize the engine in client mode if necessary
         sslEngine.setUseClientMode(parent.isUseClientMode());
 
@@ -146,8 +149,7 @@
 
         // TODO : we may not need to call this method...
         sslEngine.beginHandshake();
-        
-        
+
         handshakeStatus = sslEngine.getHandshakeStatus();
 
         handshakeComplete = false;
@@ -167,11 +169,9 @@
         try {
             sslEngine.closeInbound();
         } catch (SSLException e) {
-            LOGGER.debug(
-                    "Unexpected exception from SSLEngine.closeInbound().", e);
+            LOGGER.debug("Unexpected exception from SSLEngine.closeInbound().", e);
         }
 
-
         if (outNetBuffer != null) {
             outNetBuffer.capacity(sslEngine.getSession().getPacketBufferSize());
         } else {
@@ -235,18 +235,15 @@
         return handshakeStatus == SSLEngineResult.HandshakeStatus.NEED_WRAP && !isInboundDone();
     }
 
-    public void schedulePreHandshakeWriteRequest(NextFilter nextFilter,
-                                                 WriteRequest writeRequest) {
-        preHandshakeEventQueue.add(new IoFilterEvent(nextFilter,
-                IoEventType.WRITE, session, writeRequest));
+    public void schedulePreHandshakeWriteRequest(NextFilter nextFilter, WriteRequest writeRequest) {
+        preHandshakeEventQueue.add(new IoFilterEvent(nextFilter, IoEventType.WRITE, session, writeRequest));
     }
 
     public void flushPreHandshakeEvents() throws SSLException {
         IoFilterEvent scheduledWrite;
 
         while ((scheduledWrite = preHandshakeEventQueue.poll()) != null) {
-            parent.filterWrite(scheduledWrite.getNextFilter(), session,
-                    (WriteRequest) scheduledWrite.getParameter());
+            parent.filterWrite(scheduledWrite.getNextFilter(), session, (WriteRequest) scheduledWrite.getParameter());
         }
     }
 
@@ -280,13 +277,15 @@
     }
 
     /**
-     * Call when data read from net. Will perform inial hanshake or decrypt provided
-     * Buffer.
-     * Decrytpted data reurned by getAppBuffer(), if any.
-     *
-     * @param buf        buffer to decrypt
-     * @param nextFilter Next filter in chain
-     * @throws SSLException on errors
+     * Called when data is read from the network. Performs the initial handshake
+     * or decrypts the provided buffer. Decrypted data is returned by fetchAppBuffer(), if any.
+     * 
+     * @param buf
+     *            buffer to decrypt
+     * @param nextFilter
+     *            Next filter in chain
+     * @throws SSLException
+     *             on errors
      */
     public void messageReceived(NextFilter nextFilter, ByteBuffer buf) throws SSLException {
         // append buf to inNetBuffer
@@ -302,8 +301,9 @@
         }
 
         if (isInboundDone()) {
-            // Rewind the MINA buffer if not all data is processed and inbound is finished.
-            int inNetBufferPosition = inNetBuffer == null? 0 : inNetBuffer.position();
+            // Rewind the MINA buffer if not all data is processed and inbound
+            // is finished.
+            int inNetBufferPosition = inNetBuffer == null ? 0 : inNetBuffer.position();
             buf.position(buf.position() - inNetBufferPosition);
             inNetBuffer = null;
         }
@@ -311,7 +311,7 @@
 
     /**
      * Get decrypted application data.
-     *
+     * 
      * @return buffer with data
      */
     public IoBuffer fetchAppBuffer() {
@@ -322,7 +322,7 @@
 
     /**
      * Get encrypted data to be sent.
-     *
+     * 
      * @return buffer with data
      */
     public IoBuffer fetchOutNetBuffer() {
@@ -337,9 +337,11 @@
 
     /**
      * Encrypt provided buffer. Encrypted data returned by getOutNetBuffer().
-     *
-     * @param src data to encrypt
-     * @throws SSLException on errors
+     * 
+     * @param src
+     *            data to encrypt
+     * @throws SSLException
+     *             on errors
      */
     public void encrypt(ByteBuffer src) throws SSLException {
         if (!handshakeComplete) {
@@ -367,8 +369,7 @@
                 outNetBuffer.capacity(outNetBuffer.capacity() << 1);
                 outNetBuffer.limit(outNetBuffer.capacity());
             } else {
-                throw new SSLException("SSLEngine error during encrypt: "
-                        + result.getStatus() + " src: " + src
+                throw new SSLException("SSLEngine error during encrypt: " + result.getStatus() + " src: " + src
                         + "outNetBuffer: " + outNetBuffer);
             }
         }
@@ -378,10 +379,11 @@
 
     /**
      * Start SSL shutdown process.
-     *
-     * @return <tt>true</tt> if shutdown process is started.
-     *         <tt>false</tt> if shutdown process is already finished.
-     * @throws SSLException on errors
+     * 
+     * @return <tt>true</tt> if shutdown process is started. <tt>false</tt> if
+     *         shutdown process is already finished.
+     * @throws SSLException
+     *             on errors
      */
     public boolean closeOutbound() throws SSLException {
         if (sslEngine == null || sslEngine.isOutboundDone()) {
@@ -411,7 +413,7 @@
 
     /**
      * Decrypt in net buffer. Result is stored in app buffer.
-     *
+     * 
      * @throws SSLException
      */
     private void decrypt(NextFilter nextFilter) throws SSLException {
@@ -427,24 +429,20 @@
      * @param res
      * @throws SSLException
      */
-    private void checkStatus(SSLEngineResult res)
-            throws SSLException {
+    private void checkStatus(SSLEngineResult res) throws SSLException {
 
         SSLEngineResult.Status status = res.getStatus();
 
         /*
-        * The status may be:
-        * OK - Normal operation
-        * OVERFLOW - Should never happen since the application buffer is
-        *      sized to hold the maximum packet size.
-        * UNDERFLOW - Need to read more data from the socket. It's normal.
-        * CLOSED - The other peer closed the socket. Also normal.
-        */
-        if (status != SSLEngineResult.Status.OK
-                && status != SSLEngineResult.Status.CLOSED
-                && status != SSLEngineResult.Status.BUFFER_UNDERFLOW) {
-            throw new SSLException("SSLEngine error during decrypt: " + status
-                    + " inNetBuffer: " + inNetBuffer + "appBuffer: "
+         * The status may be: 
+         * OK - Normal operation 
+         * OVERFLOW - Should never happen since the application buffer is sized to hold the maximum
+         * packet size. 
+         * UNDERFLOW - Need to read more data from the socket. It's normal. 
+         * CLOSED - The other peer closed the socket. Also normal.
+         */
+        if (status == SSLEngineResult.Status.BUFFER_OVERFLOW) {
+            throw new SSLException("SSLEngine error during decrypt: " + status + " inNetBuffer: " + inNetBuffer + "appBuffer: "
                     + appBuffer);
         }
     }
@@ -455,77 +453,72 @@
     public void handshake(NextFilter nextFilter) throws SSLException {
         for (;;) {
             switch (handshakeStatus) {
-                case FINISHED :
-                    session.setAttribute(
-                            SslFilter.SSL_SESSION, sslEngine.getSession());
-                    handshakeComplete = true;
-                    
-                    if (!initialHandshakeComplete
-                            && session.containsAttribute(SslFilter.USE_NOTIFICATION)) {
-                        // SESSION_SECURED is fired only when it's the first handshake.
-                        // (i.e. renegotiation shouldn't trigger SESSION_SECURED.)
-                        initialHandshakeComplete = true;
-                        scheduleMessageReceived(nextFilter,
-                                SslFilter.SESSION_SECURED);
-                    }
-                    
+            case FINISHED:
+                session.setAttribute(SslFilter.SSL_SESSION, sslEngine.getSession());
+                handshakeComplete = true;
+
+                if (!initialHandshakeComplete && session.containsAttribute(SslFilter.USE_NOTIFICATION)) {
+                    // SESSION_SECURED is fired only for the first handshake
+                    // (i.e. renegotiation shouldn't trigger
+                    // SESSION_SECURED).
+                    initialHandshakeComplete = true;
+                    scheduleMessageReceived(nextFilter, SslFilter.SESSION_SECURED);
+                }
+
+                return;
+
+            case NEED_TASK:
+                handshakeStatus = doTasks();
+                break;
+
+            case NEED_UNWRAP:
+                // we need more data read
+                SSLEngineResult.Status status = unwrapHandshake(nextFilter);
+
+                if (status == SSLEngineResult.Status.BUFFER_UNDERFLOW
+                        && handshakeStatus != SSLEngineResult.HandshakeStatus.FINISHED || isInboundDone()) {
+                    // We need more data or the session is closed
                     return;
-                    
-                case NEED_TASK :
-                    handshakeStatus = doTasks();
-                    break;
-                    
-                case NEED_UNWRAP :
-                    // we need more data read
-                    SSLEngineResult.Status status = unwrapHandshake(nextFilter);
-                    
-                    if (status == SSLEngineResult.Status.BUFFER_UNDERFLOW &&
-                            handshakeStatus != SSLEngineResult.HandshakeStatus.FINISHED ||
-                            isInboundDone()) {
-                        // We need more data or the session is closed
-                        return;
-                    }
-                    
-                    break;
+                }
 
-                case NEED_WRAP :
-                    // First make sure that the out buffer is completely empty. Since we
-                    // cannot call wrap with data left on the buffer
-                    if (outNetBuffer != null && outNetBuffer.hasRemaining()) {
-                        return;
-                    }
+                break;
+
+            case NEED_WRAP:
+                // First make sure that the out buffer is completely empty,
+                // since we cannot call wrap() with data left in the
+                // buffer.
+                if (outNetBuffer != null && outNetBuffer.hasRemaining()) {
+                    return;
+                }
 
-                    SSLEngineResult result;
-                    createOutNetBuffer(0);
-                    
-                    for (;;) {
-                        result = sslEngine.wrap(emptyBuffer.buf(), outNetBuffer.buf());
-                        if (result.getStatus() == SSLEngineResult.Status.BUFFER_OVERFLOW) {
-                            outNetBuffer.capacity(outNetBuffer.capacity() << 1);
-                            outNetBuffer.limit(outNetBuffer.capacity());
-                        } else {
-                            break;
-                        }
+                SSLEngineResult result;
+                createOutNetBuffer(0);
+
+                for (;;) {
+                    result = sslEngine.wrap(emptyBuffer.buf(), outNetBuffer.buf());
+                    if (result.getStatus() == SSLEngineResult.Status.BUFFER_OVERFLOW) {
+                        outNetBuffer.capacity(outNetBuffer.capacity() << 1);
+                        outNetBuffer.limit(outNetBuffer.capacity());
+                    } else {
+                        break;
                     }
+                }
 
-                    outNetBuffer.flip();
-                    handshakeStatus = result.getHandshakeStatus();
-                    writeNetBuffer(nextFilter);
-                    break;
-            
-                default :
-                    throw new IllegalStateException("Invalid Handshaking State"
-                            + handshakeStatus);
+                outNetBuffer.flip();
+                handshakeStatus = result.getHandshakeStatus();
+                writeNetBuffer(nextFilter);
+                break;
+
+            default:
+                throw new IllegalStateException("Invalid Handshaking State" + handshakeStatus);
             }
         }
     }
 
     private void createOutNetBuffer(int expectedRemaining) {
         // SSLEngine requires us to allocate unnecessarily big buffer
-        // even for small data.  *Shrug*
-        int capacity = Math.max(
-                expectedRemaining,
-                sslEngine.getSession().getPacketBufferSize());
+        // even for small data. *Shrug*
+        int capacity = Math.max(expectedRemaining, sslEngine.getSession().getPacketBufferSize());
 
         if (outNetBuffer != null) {
             outNetBuffer.capacity(capacity);
@@ -534,8 +527,7 @@
         }
     }
 
-    public WriteFuture writeNetBuffer(NextFilter nextFilter)
-            throws SSLException {
+    public WriteFuture writeNetBuffer(NextFilter nextFilter) throws SSLException {
         // Check if any net data needed to be writen
         if (outNetBuffer == null || !outNetBuffer.hasRemaining()) {
             // no; bail out
@@ -552,16 +544,14 @@
         try {
             IoBuffer writeBuffer = fetchOutNetBuffer();
             writeFuture = new DefaultWriteFuture(session);
-            parent.filterWrite(nextFilter, session, new DefaultWriteRequest(
-                    writeBuffer, writeFuture));
+            parent.filterWrite(nextFilter, session, new DefaultWriteRequest(writeBuffer, writeFuture));
 
             // loop while more writes required to complete handshake
             while (needToCompleteHandshake()) {
                 try {
                     handshake(nextFilter);
                 } catch (SSLException ssle) {
-                    SSLException newSsle = new SSLHandshakeException(
-                            "SSL handshake failed.");
+                    SSLException newSsle = new SSLHandshakeException("SSL handshake failed.");
                     newSsle.initCause(ssle);
                     throw newSsle;
                 }
@@ -569,8 +559,7 @@
                 IoBuffer outNetBuffer = fetchOutNetBuffer();
                 if (outNetBuffer != null && outNetBuffer.hasRemaining()) {
                     writeFuture = new DefaultWriteFuture(session);
-                    parent.filterWrite(nextFilter, session,
-                            new DefaultWriteRequest(outNetBuffer, writeFuture));
+                    parent.filterWrite(nextFilter, session, new DefaultWriteRequest(outNetBuffer, writeFuture));
                 }
             }
         } finally {
@@ -590,7 +579,7 @@
             return;
         }
 
-        SSLEngineResult res = unwrap0();
+        SSLEngineResult res = unwrap0(!HANDSHAKE_FINISHED);
 
         // prepare to be written again
         if (inNetBuffer.hasRemaining()) {
@@ -615,17 +604,17 @@
             return SSLEngineResult.Status.BUFFER_UNDERFLOW;
         }
 
-        SSLEngineResult res = unwrap0();
+        SSLEngineResult res = unwrap0(!HANDSHAKE_FINISHED);
         handshakeStatus = res.getHandshakeStatus();
 
         checkStatus(res);
 
-        // If handshake finished, no data was produced, and the status is still ok,
+        // If the handshake finished, no data was produced, and the status is
+        // still OK,
         // try to unwrap more
-        if (handshakeStatus == SSLEngineResult.HandshakeStatus.FINISHED
-                && res.getStatus() == SSLEngineResult.Status.OK
+        if (handshakeStatus == SSLEngineResult.HandshakeStatus.FINISHED && res.getStatus() == SSLEngineResult.Status.OK
                 && inNetBuffer.hasRemaining()) {
-            res = unwrap0();
+            res = unwrap0(HANDSHAKE_FINISHED);
 
             // prepare to be written again
             if (inNetBuffer.hasRemaining()) {
@@ -647,10 +636,8 @@
         return res.getStatus();
     }
 
-    private void renegotiateIfNeeded(NextFilter nextFilter, SSLEngineResult res)
-            throws SSLException {
-        if (res.getStatus() != SSLEngineResult.Status.CLOSED
-                && res.getStatus() != SSLEngineResult.Status.BUFFER_UNDERFLOW
+    private void renegotiateIfNeeded(NextFilter nextFilter, SSLEngineResult res) throws SSLException {
+        if (res.getStatus() != SSLEngineResult.Status.CLOSED && res.getStatus() != SSLEngineResult.Status.BUFFER_UNDERFLOW
                 && res.getHandshakeStatus() != SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING) {
             // Renegotiation required.
             handshakeComplete = false;
@@ -659,7 +646,7 @@
         }
     }
 
-    private SSLEngineResult unwrap0() throws SSLException {
+    private SSLEngineResult unwrap0(boolean finished) throws SSLException {
         if (appBuffer == null) {
             appBuffer = IoBuffer.allocate(inNetBuffer.remaining());
         } else {
@@ -667,16 +654,22 @@
         }
 
         SSLEngineResult res;
+
+        Status status = null;
+        HandshakeStatus handshakeStatus = null;
+
         do {
             res = sslEngine.unwrap(inNetBuffer.buf(), appBuffer.buf());
-            if (res.getStatus() == SSLEngineResult.Status.BUFFER_OVERFLOW) {
+            status = res.getStatus();
+            handshakeStatus = res.getHandshakeStatus();
+
+            if (status == SSLEngineResult.Status.BUFFER_OVERFLOW) {
                 appBuffer.capacity(appBuffer.capacity() << 1);
                 appBuffer.limit(appBuffer.capacity());
                 continue;
             }
-        } while ((res.getStatus() == SSLEngineResult.Status.OK || res.getStatus() == SSLEngineResult.Status.BUFFER_OVERFLOW) &&
-                 (handshakeComplete && res.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING ||
-                  res.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.NEED_UNWRAP));
+        } while (((status == SSLEngineResult.Status.OK) || (status == SSLEngineResult.Status.BUFFER_OVERFLOW))
+                && (((finished || handshakeComplete) && (handshakeStatus == SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING)) || (handshakeStatus == SSLEngineResult.HandshakeStatus.NEED_UNWRAP)));
 
         return res;
     }
@@ -686,22 +679,24 @@
      */
     private SSLEngineResult.HandshakeStatus doTasks() {
         /*
-         * We could run this in a separate thread, but I don't see the need
-         * for this when used from SSLFilter. Use thread filters in MINA instead?
+         * We could run this in a separate thread, but I don't see the need for
+         * this when used from SSLFilter. Use thread filters in MINA instead?
          */
         Runnable runnable;
         while ((runnable = sslEngine.getDelegatedTask()) != null) {
-            // TODO : we may have to use a thread pool here to improve the performances
+            // TODO : we may have to use a thread pool here to improve
+            // performance
             runnable.run();
         }
         return sslEngine.getHandshakeStatus();
     }
 
     /**
-     * Creates a new MINA buffer that is a deep copy of the remaining bytes
-     * in the given buffer (between index buf.position() and buf.limit())
-     *
-     * @param src the buffer to copy
+     * Creates a new MINA buffer that is a deep copy of the remaining bytes in
+     * the given buffer (between index buf.position() and buf.limit())
+     * 
+     * @param src
+     *            the buffer to copy
      * @return the new buffer, ready to read from
      */
     public static IoBuffer copy(ByteBuffer src) {